diff --git a/.gitignore b/.gitignore index bdcb067fc26d2a18ed88034ab616c08095794e17..fdc61ee8251a63953e5e92cff602e7ace9653700 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,6 @@ node_modules /.tf_configure.bazelrc /bazel-* /bazel_pip -/third_party/eigen3/mkl_include -/third_party/mkl/* /tools/python_bin_path.sh /tools/git/gen /pip_test diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..69393c377589cc707d6c079e575346564b9c3fbf --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,52 @@ +# Where component owners are known, add them here. + +tensorflow/core/platform/windows/* @mrry +tensorflow/java/* @asimshankar +tensorflow/tensorboard/* @jart @dandelionmane +tensorflow/tools/docs/* @markdaoust + +# contrib + +# NEED OWNER: tensorflow/contrib/avro/* +tensorflow/contrib/batching/* @alextp @chrisolston +tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon +tensorflow/contrib/cmake/* @mrry @benoitsteiner +tensorflow/contrib/copy_graph/* @tucker @poxvoculi +tensorflow/contrib/crf/* @kentonl +tensorflow/contrib/data/* @mrry +tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi +tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo +tensorflow/contrib/ffmpeg/* @fredbertsch +# NEED OWNER: tensorflow/contrib/framework/* +tensorflow/contrib/graph_editor/* @purpledog +# NEED OWNER: tensorflow/contrib/grid_rnn/* +tensorflow/contrib/hvx/* @satok16 +tensorflow/contrib/imperative/* @keveman +tensorflow/contrib/integrate/* @shoyer +tensorflow/contrib/kernel_methods/* @petrosmol +tensorflow/contrib/ios_examples/* @petewarden +tensorflow/contrib/labeled_tensor/* @shoyer +tensorflow/contrib/layers/* @fchollet @martinwicke +tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp +tensorflow/contrib/linalg/* @langmore +tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis +tensorflow/contrib/lookup/* @ysuematsu @andreasst +tensorflow/contrib/losses/* @alextp @ispirmustafa 
+tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg +tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa +tensorflow/contrib/nccl/* @cwhipkey @zheng-xq +tensorflow/contrib/opt/* @strategist333 +tensorflow/contrib/pi_examples/* @maciekcc +tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman +tensorflow/contrib/rnn/* @ebrevdo +tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh +tensorflow/contrib/seq2seq/* @lukaszkaiser +tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh +tensorflow/contrib/slim/* @sguada @thenbasilmanran +tensorflow/contrib/stateless/* @girving +tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst +tensorflow/contrib/testing/* @dandelionmane +tensorflow/contrib/timeseries/* @allenlavoie +tensorflow/contrib/tpu/* @frankchn @saeta @jhseu +tensorflow/contrib/training/* @joel-shor @ebrevdo +tensorflow/contrib/util/* @sherrym diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..10fd595fec7f240c3fdc871e1f32cc83f2ffd46d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,70 @@ +# TensorFlow Code of Conduct + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
+ + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic address, without explicit permission +* Conduct which could reasonably be considered inappropriate for the forum in which it occurs. + +All TensorFlow forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable. + + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + + +## Scope + +This Code of Conduct applies to all content on tensorflow.org, TensorFlow’s GitHub organization, or any other official TensorFlow web presence allowing for community interactions, as well as at all official TensorFlow events, whether offline or online. 
+ +The Code of Conduct also applies within project spaces and in public spaces whenever an individual is representing TensorFlow or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed or de facto representative at an online or offline event. + + +## Conflict Resolution + +Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between. + +If the behaviour is threatening or harassing, or for other reasons requires immediate escalation, please see below. + +However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute. + +If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict: + +1. Address the perceived conflict directly with those involved, preferably in a real-time medium. +2. If this fails, get a third party (e.g. a mutual friend, and/or someone with background on the issue, but not involved in conflict) to intercede. +3. If you are still unable to resolve the conflict, and you believe it rises to harassment or another code of conduct violation, report it. + + +## Reporting Violations + +Violations of the Code of Conduct can be reported to TensorFlow’s Project Steward at conduct@tensorflow.org. The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. 
We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. + +Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report. + + +## Enforcement + +If the Project Steward receives a report alleging a violation of the Code of Conduct, the Project Steward will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Steward will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Steward may issue sanctions without notice. + + +## Attribution + +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at http://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct. diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 6f4c048ce83fb47a611b5dfe08e0fde0779994c0..2bf2c754cf64ec3bac22a22fbafcebbd4dc54bf4 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -1,11 +1,12 @@ Please go to Stack Overflow for help and support: -http://stackoverflow.com/questions/tagged/tensorflow +https://stackoverflow.com/questions/tagged/tensorflow If you open a GitHub issue, here is our policy: 1. It must be a bug or a feature request. 2. The form below must be filled out. +3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues). **Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. 
We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow. @@ -16,6 +17,7 @@ If you open a GitHub issue, here is our policy: - **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: - **TensorFlow installed from (source or binary)**: - **TensorFlow version (use command below)**: +- **Python version**: - **Bazel version (if compiling from source)**: - **CUDA/cuDNN version**: - **GPU model and memory**: diff --git a/README.md b/README.md index e7dbf57b25a6276498ce26f1df41e2a54d1fc159..4e17182f8117f86fbdfc96dd0926804fda0c310d 100644 --- a/README.md +++ b/README.md @@ -9,37 +9,38 @@ | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | **TensorFlow** is an open source software library for numerical computation using -data flow graphs. Nodes in the graph represent mathematical operations, while +data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow between them. This flexible architecture lets you deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting code. TensorFlow also includes TensorBoard, a data visualization toolkit. 
TensorFlow was originally developed by researchers and engineers -working on the Google Brain team within Google's Machine Intelligence research +working on the Google Brain team within Google's Machine Intelligence Research organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -**If you'd like to contribute to TensorFlow, be sure to review the [contribution +**If you want to contribute to TensorFlow, be sure to review the [contribution guidelines](CONTRIBUTING.md).** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs, but please see -[Community](https://www.tensorflow.org/community/) for general questions -and discussion.** +tracking requests and bugs. So please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions +and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** ## Installation -*See [Installing TensorFlow](https://www.tensorflow.org/install/) for instructions on how to install our release binaries or how to build from source.* +*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.* People who are a little more adventurous can also try our nightly binaries: -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 
3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0rc2-cp35-cp35m-linux_x86_64.whl) ([build 
history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) -* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) -* Windows CPU-only: [Python 3.5 
64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/)) + +* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc0-cp34-cp34m-linux_x86_64.whl) ([build 
history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) +* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Mac CPU-only: [Python 
2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) +* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 
64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/)) +* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/)) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) @@ -55,16 +56,17 @@ $ python 'Hello, TensorFlow!' 
>>> a = tf.constant(10) >>> b = tf.constant(32) ->>> sess.run(a+b) +>>> sess.run(a + b) 42 >>> ``` ## For more information -* [TensorFlow website](https://tensorflow.org) -* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf) +* [TensorFlow website](https://www.tensorflow.org) +* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) +* [TensorFlow course at Stanford](https://web.stanford.edu/class/cs20si) -The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/about/#community) for an incomplete list. +Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index d22c5c62fe01e5d3e2bc0cd4657aff692ee734bf..e7c086164a261b1169446738abe3b5e390d2d798 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,112 @@ +# Release 1.3.0 + +## Major Features and Improvements +* Added canned estimators to TensorFlow library. List of added estimators: `DNNClassifier`, `DNNRegressor`, `LinearClassifier`, `LinearRegressor`, `DNNLinearCombinedClassifier`, `DNNLinearCombinedRegressor`. +* All our prebuilt binaries have been built with cuDNN 6. +* Adds a file cache to the GCS filesystem with configurable max staleness for file contents. This permits caching of file contents across close/open boundaries. +* Added an axis parameter to `tf.gather`. +* Added a `constant_values` keyword argument to `tf.pad`. +* Adds `Dataset.interleave` transformation. +* Add `ConcatenateDataset` to concatenate two datasets. +* Added Mobilenet support to TensorFlow for Poets training script. +* Adds a block cache to the GCS filesystem with configurable block size and count. +* SinhArcSinh bijector added. 
+* Added `Dataset.list_files` API. +* Introduces new operations and Python bindings for the Cloud TPU. +* Adding TensorFlow-iOS CocoaPod for symmetry with tensorflow-android. +* Introduces base implementations of ClusterResolvers. +* Unify memory representations of TensorShape and PartialTensorShape. As a consequence, tensors now have a maximum of 254 dimensions, not 255. +* Changed references to LIBXSMM to use version 1.8.1. +* TensorFlow Debugger (tfdbg): Display summaries of numeric tensor values with the `-s` flag to command `print_tensor` or `pt`. +* Initial release of the statistical distribution library `tf.distributions`. +* GPU kernels and speed improvements for unary `tf.where` and `tf.nn.top_k`. +* Monotonic Attention wrappers added to `tf.contrib.seq2seq`. + +## Breaking Changes to the API +* `tf.RewriterConfig` was removed from the Python API after being available in 1.2 release candidates (it was never in an actual release). Graph rewriting is still available, just not as `tf.RewriterConfig`. Instead add an explicit import. +* Breaking change to `tf.contrib.data.Dataset` APIs that expect a nested structure. Lists are now converted to `tf.Tensor` implicitly. You may need to change uses of lists to tuples in existing code. In addition, dicts are now supported as a nested structure. + +## Changes to contrib APIs +* Adds tf.contrib.nn.rank_sampled_softmax_loss, a sampled-softmax variant that can improve rank loss. +* `tf.contrib.metrics`.{streaming_covariance,streaming_pearson_correlation} modified to return nan when they have seen less than or equal to 1 unit of weight. +* Adds time series models to contrib. See contrib/timeseries/README.md for details. +* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs + +## Bug Fixes and Other Changes +* Fixes 'strides' and 'begin' dtype mismatch when slicing using int64 Tensor index in python. +* Improved convolution padding documentation. +* Add a tag constant, gpu, to present graph with GPU support. 
+* `saved_model.utils` now supports SparseTensors transparently. +* A more efficient implementation of non-max suppression. +* Add support for the shrinkage-type L2 to FtrlOptimizer in addition to the online L2 it already supports. +* Fix negative variance in moments calculation. +* Expand UniqueOp Benchmark Tests to cover more collision cases. +* Improves stability of GCS filesystem on Mac. +* Add time estimation to HloCostAnalysis. +* Fixed the bug in Estimator that params in constructor was not a deepcopy of the user provided one. This bug inadvertently enabled users to mutate the params after the creation of Estimator, leading to potentially undefined behavior. +* Added None check for save_path in `saver.restore`. +* Register devices under their legacy names in device_mgr to ease the transition to clusterspec-propagated configurations. +* VectorExponential added to distributions. +* Add a bitwise module with bitwise_and, bitwise_or, bitwise_xor, and invert functions. +* Add fixed-grid ODE integration routines. +* Allow passing bounds to ScipyOptimizerInterface. +* Correctness fixes for fft_length parameter to `tf.spectral.rfft` & `tf.spectral.irfft`. +* Exported model signatures using the 'predict' method will no longer have their input and output keys silently ignored and rewritten to 'inputs' and 'outputs'. If a model was exported with different names before 1.2, and is now served with tensorflow/serving, it will accept requests using 'inputs' and 'outputs'. Starting at 1.2, such a model will accept the keys specified during export. Therefore, inference requests using 'inputs' and 'outputs' may start to fail. To fix this, either update any inference clients to send requests with the actual input and output keys used by the trainer code, or conversely, update the trainer code to name the input and output Tensors 'inputs' and 'outputs', respectively. 
Signatures using the 'classify' and 'regress' methods are not affected by this change; they will continue to standardize their input and output keys as before. +* Add in-memory caching to the Dataset API. +* Set default end_of_sequence variable in datasets iterators to false. +* [Performance] Increase performance of `tf.layers.conv2d` when setting use_bias=True by 2x by using nn.bias_add. +* Update iOS examples to use CocoaPods, and moved to tensorflow/examples/ios. +* Adds a family= attribute in `tf.summary` ops to allow controlling the tab name used in Tensorboard for organizing summaries. +* When GPU is configured, do not require --config=cuda, instead, automatically build for GPU if this is requested in the configure script. +* Fix incorrect sampling of small probabilities in CPU/GPU multinomial. +* Add a list_devices() API on sessions to list devices within a cluster. Additionally, this change augments the ListDevices master API to support specifying a session. +* Allow uses of over-parameterized separable convolution. +* TensorForest multi-regression bug fix. +* Framework now supports armv7, cocoapods.org now displays correct page. +* Script to create iOS framework for CocoaPods. +* Android releases of TensorFlow are now pushed to jcenter for easier integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md for more details. +* Fixed a bug that prevented tfdbg from functioning with multi-GPU setups. +* Fixed a bug that prevented tfdbg from working with `tf.Session.make_callable`. 
+ +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4F2E4A2E, Adriano Carmezim, Adrià Arrufat, Alan Yee, Alex Lattas, Alex Rothberg, +Alexandr Baranezky, Ali Siddiqui, Andreas Solleder, Andrei Costinescu, Andrew Hundt, +Androbin, Andy Kernahan, Anish Shah, Anthony Platanios, Arvinds-Ds, b1rd, Baptiste +Arnaud, Ben Mabey, Benedikt Linse, Beomsu Kim, Bo Wang, Boyuan Deng, Brett Koonce, +Bruno Rosa, Carl Thomé, Changming Sun, Chase Roberts, Chirag Bhatia, Chris Antaki, +Chris Hoyean Song, Chris Tava, Christos Nikolaou, Croath Liu, cxx, Czxck001, Daniel +Ylitalo, Danny Goodman, Darren Garvey, David Brailovsky, David Norman, DavidNorman, +davidpham87, ddurham2, Dhruv, DimanNe, Drew Hintz, Dustin Tran, Earthson Lu, ethiraj, +Fabian Winnen, Fei Sun, Freedom" Koan-Sin Tan, Fritz Obermeyer, Gao, Xiang, Gautam, +Guenther Schmuelling, Gyu-Ho Lee, Hauke Brammer, horance, Humanity123, J Alammar, +Jayeol Chun, Jeroen BéDorf, Jianfei Wang, jiefangxuanyan, Jing Jun Yin, Joan Puigcerver, +Joel Hestness, Johannes Mayer, John Lawson, Johnson145, Jon Malmaud, Jonathan Alvarez-Gutierrez, +Juang, Yi-Lin, Julian Viereck, Kaarthik Sivashanmugam, Karl Lessard, karl@kubx.ca, Kevin +Carbone, Kevin Van Der Burgt, Kongsea, ksellesk, lanhin, Lef Ioannidis, Liangliang He, +Louis Tiao, Luke Iwanski, LáSzló Csomor, magixsno, Mahmoud Abuzaina, Marcel Hlopko, Mark +Neumann, Maxwell Paul Brickner, mdfaijul, MichaëL Defferrard, Michał JastrzęBski, Michele +Colombo, Mike Brodie, Mosnoi Ion, mouradmourafiq, myPrecious, Nayana Thorat, +Neeraj Kashyap, Nelson Liu, Niranjan Hasabnis, Olivier Moindrot, orome, Pankaj Gupta, Paul +Van Eck, peeyush18, Peng Yu, Pierre, preciousdp11, qjivy, Raingo, raoqiyu, ribx, Richard S. 
+Imaoka, Rishabh Patel, Robert Walecki, Rockford Wei, Ryan Kung, Sahil Dua, Sandip Giri, Sayed +Hadi Hashemi, sgt101, Shitian Ni, Shuolongbj, Siim PõDer, Simon Perkins, sj6077, SOLARIS, +Spotlight0xff, Steffen Eberbach, Stephen Fox, superryanguo, Sven Mayer, Tapan Prakash, +Tiago Morais Morgado, Till Hoffmann, Tj Rana, Vadim Markovtsev, vhasanov, Wei Wu, +windead, Yan (Asta) Li, Yan Chen, Yann Henon, Yi Wang, Yong Tang, yorkie, Yuan (Terry) +Tang, Yuxin Wu, zhengjiajin, zhongzyd, 黄璞 + +We are also grateful to all who filed issues or helped resolve them, asked and +answered questions, and were part of inspiring discussions. + +# Release 1.2.1 + +## Bug Fixes and Other Changes +* Updating markdown version required to >= 2.6.8. +* Support tensors as dropout rates again, by removing the min(max(..)) + # Release 1.2.0 ## Major Features and Improvements @@ -59,37 +168,6 @@ integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md for more details. -* RNNCells' variable names have been renamed for consistency with Keras layers. - Specifically, the previous variable names "weights" and "biases" have - been changed to "kernel" and "bias", respectively. - This may cause backward incompatibility with regard to your old - checkpoints containing such RNN cells, in which case you can use the tool - [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py) - to convert the variable names in your old checkpoints. -* Many of the RNN functions and classes that were in the `tf.nn` namespace - before the 1.0 release and which were moved to `tf.contrib.rnn` have now - been moved back to the core namespace. This includes - `RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These - now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards - compatibility). 
The original `tf.nn.rnn` function is now `tf.nn.static_rnn`, - and the bidirectional static and state saving static rnn functions are also - now back in the `tf.nn` namespace. - - Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and - `OutputProjectionWrapper`, which will slowly be moved to deprecation - in `tf.contrib.rnn`. These are inefficient wrappers that should often - be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post- - processing of the rnn. For RNN decoding, this functionality has been replaced - with an alternative API in `tf.contrib.seq2seq`. -* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of - optimized deep learning primitives: In addition to matrix multiplication and - convolution, these building blocks include: - Direct batched convolution - Pooling: maximum, minimum, average - Normalization: LRN, batch normalization - Activation: rectified linear unit (ReLU) - Data manipulation: multi-dimensional transposition (conversion), split, - concat, sum and scale. ## Deprecations @@ -113,6 +191,8 @@ checkpoints containing such RNN cells, in which case you can use the [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py) to convert the variable names in your old checkpoints. +* Added `tf.contrib.kernel_methods` module with Ops and estimators for primal + (explicit) kernel methods in TensorFlow. ## Bug Fixes and Other Changes * In python, `Operation.get_attr` on type attributes returns the Python DType diff --git a/WORKSPACE b/WORKSPACE index 74ce13f4e88710050ac3f5aa22e6de0375da9694..6b5d24560ca416bcff10355cf760e6c4af928137 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -32,6 +32,9 @@ load("//tensorflow:workspace.bzl", "tf_workspace") # name="androidndk", # path="", # # This needs to be 14 or higher to compile TensorFlow. 
+# # Please specify an API level >= 21 to build for 64-bit +# # architectures, or the Android NDK will automatically select the biggest +# # API level that it supports without notice. # # Note that the NDK version is not the API level. # api_level=14) diff --git a/arm_compiler.BUILD b/arm_compiler.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b231d0180e3953e72c350a53544419ae634a355a --- /dev/null +++ b/arm_compiler.BUILD @@ -0,0 +1,81 @@ +package(default_visibility = ['//visibility:public']) + +filegroup( + name = 'gcc', + srcs = [ + 'bin/arm-linux-gnueabihf-gcc', + ], +) + +filegroup( + name = 'ar', + srcs = [ + 'bin/arm-linux-gnueabihf-ar', + ], +) + +filegroup( + name = 'ld', + srcs = [ + 'bin/arm-linux-gnueabihf-ld', + ], +) + +filegroup( + name = 'nm', + srcs = [ + 'bin/arm-linux-gnueabihf-nm', + ], +) + +filegroup( + name = 'objcopy', + srcs = [ + 'bin/arm-linux-gnueabihf-objcopy', + ], +) + +filegroup( + name = 'objdump', + srcs = [ + 'bin/arm-linux-gnueabihf-objdump', + ], +) + +filegroup( + name = 'strip', + srcs = [ + 'bin/arm-linux-gnueabihf-strip', + ], +) + +filegroup( + name = 'as', + srcs = [ + 'bin/arm-linux-gnueabihf-as', + ], +) + +filegroup( + name = 'compiler_pieces', + srcs = glob([ + 'arm-linux-gnueabihf/**', + 'libexec/**', + 'lib/gcc/arm-linux-gnueabihf/**', + 'include/**', + ]), +) + +filegroup( + name = 'compiler_components', + srcs = [ + ':gcc', + ':ar', + ':ld', + ':nm', + ':objcopy', + ':objdump', + ':strip', + ':as', + ], +) diff --git a/configure b/configure index 602124225fe0712135798a779e509a16fe2ccc79..9c21d2b03a27714f05094667691e74c16fa89f35 100755 --- a/configure +++ b/configure @@ -3,879 +3,12 @@ set -e set -o pipefail -MIN_BAZEL_VERSION=0.4.5 - -# Find out the absolute path to where ./configure resides -pushd `dirname $0` > /dev/null -SOURCE_BASE_DIR=`pwd -P` -popd > /dev/null - -PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" - -function is_linux() { - [[ "${PLATFORM}" == "linux" ]] -} - -function is_macos() { - 
[[ "${PLATFORM}" == "darwin" ]] -} - -function is_windows() { - # On windows, the shell script is actually running in msys - [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] -} - -function sed_in_place() { - sed -e $1 $2 > "$2.bak" - mv "$2.bak" $2 -} - -function write_to_bazelrc() { - echo "$1" >> .tf_configure.bazelrc -} - -function write_action_env_to_bazelrc() { - write_to_bazelrc "build --action_env $1=\"$2\"" -} - -function python_path { - "$PYTHON_BIN_PATH" - <&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - PYTHON_BIN_PATH="" - # Retry - done - - if [ -z "$PYTHON_LIB_PATH" ]; then - # Split python_path into an array of paths, this allows path containing spaces - IFS=',' read -r -a python_lib_path <<< "$(python_path)" - - if [ 1 = "$USE_DEFAULT_PYTHON_LIB_PATH" ]; then - PYTHON_LIB_PATH=${python_lib_path[0]} - echo "Using python library path: $PYTHON_LIB_PATH" - - else - echo "Found possible Python library paths:" - for x in "${python_lib_path[@]}"; do - echo " $x" - done - set -- "${python_lib_path[@]}" - echo "Please input the desired Python library path to use. Default is [$1]" - read b || true - if [ "$b" == "" ]; then - PYTHON_LIB_PATH=${python_lib_path[0]} - echo "Using python library path: $PYTHON_LIB_PATH" - else - PYTHON_LIB_PATH="$b" - fi - fi - fi - - if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then - echo "PYTHON_BIN_PATH is not executable. Is it the python binary?" - exit 1 - fi - - local python_major_version - python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);' | head -c1) - if [ -z "$python_major_version" ]; then - echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?" 
- exit 1 - fi - - # Convert python path to Windows style before writing into bazel.rc - if is_windows; then - PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")" - PYTHON_LIB_PATH="$(cygpath -m "$PYTHON_LIB_PATH")" - fi - - # Set-up env variables used by python_configure.bzl - write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH" - write_action_env_to_bazelrc "PYTHON_LIB_PATH" "$PYTHON_LIB_PATH" - write_to_bazelrc "build --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" - write_to_bazelrc "build --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" - write_to_bazelrc "build --force_python=py$python_major_version" - write_to_bazelrc "build --host_force_python=py$python_major_version" - write_to_bazelrc "build --python${python_major_version}_path=\"$PYTHON_BIN_PATH\"" - write_to_bazelrc "test --force_python=py$python_major_version" - write_to_bazelrc "test --host_force_python=py$python_major_version" - write_to_bazelrc "test --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" - write_to_bazelrc "test --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" - write_to_bazelrc "run --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" - write_to_bazelrc "run --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" - - # Write tools/python_bin_path.sh - echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh -} - -function version { - echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; -} - - -bazel version > bazel.version -curr_bazel_version=$(head -n 1 bazel.version | cut -d ' ' -f3) -rm -f bazel.version - - -echo "You have bazel $curr_bazel_version installed." -if [ -z "$curr_bazel_version" ]; then - echo "WARNING: current bazel installation is not a release version." - echo "Make sure you are running at least bazel $MIN_BAZEL_VERSION." -elif [ "$(version "$MIN_BAZEL_VERSION")" -gt "$(version "$curr_bazel_version")" ]; then - echo "Please upgrade your bazel installation to version $MIN_BAZEL_VERSION or higher to build TensorFlow!" - echo "Exiting..." 
- exit 1 -fi - -# This file contains customized config settings. -rm -f .tf_configure.bazelrc -touch .tf_configure.bazelrc -if [[ ! -e .bazelrc ]]; then - if [[ -e "${HOME}/.bazelrc" ]]; then - echo "import ${HOME}/.bazelrc" >.bazelrc - else - touch .bazelrc - fi -fi -sed_in_place "/tf_configure/d" .bazelrc -echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc - -# Delete any leftover BUILD files from the Makefile build, which would interfere -# with Bazel parsing. -MAKEFILE_DOWNLOAD_DIR=tensorflow/contrib/makefile/downloads -if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then - find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete -fi - -setup_python - -## Set up MKL related environment settings -while [ "$TF_NEED_MKL" == "" ]; do - fromuser="" - read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT - fromuser="1" - case $INPUT in - [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; - [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - * ) echo "Invalid selection: " $INPUT;; - esac -done - -OSNAME=`uname -s` - -if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL - while [ "$TF_DOWNLOAD_MKL" == "" ]; do - fromuser="" - read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT - fromuser="1" - case $INPUT in - [Yy]* ) TF_DOWNLOAD_MKL=1;; - [Nn]* ) TF_DOWNLOAD_MKL=0;; - "" ) TF_DOWNLOAD_MKL=1;; - * ) echo "Invalid selection: " $INPUT; exit 1;; - esac - done - - if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then - DST=`dirname $0` - ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz - GITHUB_RELEASE_TAG=v0.7 - MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME" - if ! 
[ -e "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" ]; then - curl -fSsL -o "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" "${MKLURL}" - fi - tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/ - extracted_dir_name="${ARCHIVE_BASENAME%.*}" - MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name - MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` - - else - default_mkl_path=/opt/intel/mklml - fromuser="" - if [ -z "$MKL_INSTALL_PATH" ]; then - read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH - fromuser="1" - fi - if [ -z "$MKL_INSTALL_PATH" ]; then - MKL_INSTALL_PATH=$default_mkl_path - fi - # Result returned from "read" will be used unexpanded. That make "~" unusable. - # Going through one more level of expansion to handle that. - MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` - fi - - if [ "$OSNAME" == "Linux" ]; then - # Full MKL configuration - MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version? - MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION? - - # MKL-ML configuration - MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version? - MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION? 
- elif [ "$OSNAME" == "Darwin" ]; then - echo "Darwin is unsupported yet"; - exit 1 - fi - - if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then - ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/ - ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/ - ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ - ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include - loc=$(locate -e libdl.so.2 | sed -n 1p) - ln -sf $loc third_party/mkl/libdl.so.2 - elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then - ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/ - ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/ - ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ - ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include - loc=$(locate -e libdl.so.2 | sed -n 1p) - ln -sf $loc third_party/mkl/libdl.so.2 - else - echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists"; - exit 1 - fi - -cat > third_party/mkl/mkl.config <&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - CLANG_CUDA_COMPILER_PATH="" - # Retry -done - -# Find out where the CUDA toolkit is installed -while true; do - # Configure the Cuda SDK version to use. - if [ -z "$TF_CUDA_VERSION" ]; then - read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 8.0]: " TF_CUDA_VERSION - fi - - fromuser="" - if [ -z "$CUDA_TOOLKIT_PATH" ]; then - default_cuda_path=/usr/local/cuda - if is_windows; then - if [ -z "$CUDA_PATH" ]; then - default_cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0" - else - default_cuda_path="$(cygpath -m "$CUDA_PATH")" - fi - elif is_linux; then - # If the default doesn't exist, try an alternative default. - if [ ! -d $default_cuda_path ] && [ -d /opt/cuda ]; then - default_cuda_path=/opt/cuda - fi - fi - read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. 
[Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH - fromuser="1" - if [ -z "$CUDA_TOOLKIT_PATH" ]; then - CUDA_TOOLKIT_PATH="$default_cuda_path" - fi - fi - - if [[ -z "$TF_CUDA_VERSION" ]]; then - TF_CUDA_EXT="" - else - TF_CUDA_EXT=".$TF_CUDA_VERSION" - fi - - if is_windows; then - CUDA_RT_LIB_PATH="lib/x64/cudart.lib" - elif is_linux; then - CUDA_RT_LIB_PATH="lib64/libcudart.so${TF_CUDA_EXT}" - elif is_macos; then - CUDA_RT_LIB_PATH="lib/libcudart${TF_CUDA_EXT}.dylib" - fi - - if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then - export CUDA_TOOLKIT_PATH - write_action_env_to_bazelrc "CUDA_TOOLKIT_PATH" "$CUDA_TOOLKIT_PATH" - export TF_CUDA_VERSION - break - fi - echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH} cannot be found" - - if [ -z "$fromuser" ]; then - exit 1 - fi - # Retry - TF_CUDA_VERSION="" - CUDA_TOOLKIT_PATH="" -done - -# Set default CUDA version if not set -if [ -z "$TF_CUDA_VERSION" ]; then - TF_CUDA_VERSION="8.0" - export TF_CUDA_VERSION -fi -write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION" - -# Set up which gcc nvcc should use as the host compiler -# No need to set this on Windows -while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do - fromuser="" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - default_gcc_host_compiler_path=$(which gcc || true) - cuda_bin_symlink="$CUDA_TOOLKIT_PATH/bin/gcc" - if [ -L "$cuda_bin_symlink" ]; then - default_gcc_host_compiler_path=$(readlink $cuda_bin_symlink) - fi - read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH - fromuser="1" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" - fi - fi - if [ -e "$GCC_HOST_COMPILER_PATH" ]; then - export GCC_HOST_COMPILER_PATH - write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH" - break - fi - echo "Invalid gcc path. 
${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - GCC_HOST_COMPILER_PATH="" - # Retry -done - -# Find out where the cuDNN library is installed -while true; do - # Configure the cuDNN version to use. - if [ -z "$TF_CUDNN_VERSION" ]; then - read -p "Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 6.0]: " TF_CUDNN_VERSION - fi - - fromuser="" - if [ -z "$CUDNN_INSTALL_PATH" ]; then - default_cudnn_path=${CUDA_TOOLKIT_PATH} - read -p "Please specify the location where cuDNN $TF_CUDNN_VERSION library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH - fromuser="1" - if [ -z "$CUDNN_INSTALL_PATH" ]; then - CUDNN_INSTALL_PATH=$default_cudnn_path - fi - # Result returned from "read" will be used unexpanded. That make "~" unusable. - # Going through one more level of expansion to handle that. - CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"` - fi - - if [[ -z "$TF_CUDNN_VERSION" ]]; then - TF_CUDNN_EXT="" - else - TF_CUDNN_EXT=".$TF_CUDNN_VERSION" - fi - - if is_windows; then - CUDA_DNN_LIB_PATH="lib/x64/cudnn.lib" - CUDA_DNN_LIB_ALT_PATH="lib/x64/cudnn.lib" - elif is_linux; then - CUDA_DNN_LIB_PATH="lib64/libcudnn.so${TF_CUDNN_EXT}" - CUDA_DNN_LIB_ALT_PATH="libcudnn.so${TF_CUDNN_EXT}" - elif is_macos; then - CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}.dylib" - CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib" - fi - - if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" ] || [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then - export TF_CUDNN_VERSION - write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION" - export CUDNN_INSTALL_PATH - write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH" - break - fi - - if is_linux; then - if ! 
type ldconfig > /dev/null 2>&1; then - LDCONFIG_BIN=/sbin/ldconfig - else - LDCONFIG_BIN=ldconfig - fi - CUDNN_PATH_FROM_LDCONFIG="$($LDCONFIG_BIN -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')" - if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then - export TF_CUDNN_VERSION - export CUDNN_INSTALL_PATH - CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})" - write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH" - break - fi - fi - echo "Invalid path to cuDNN ${CUDNN_VERSION} toolkit. Neither of the following two files can be found:" - echo "${CUDNN_INSTALL_PATH}/${CUDA_DNN_LIB_PATH}" - echo "${CUDNN_INSTALL_PATH}/${CUDA_DNN_LIB_ALT_PATH}" - if is_linux; then - echo "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" - fi - - if [ -z "$fromuser" ]; then - exit 1 - fi - # Retry - TF_CUDNN_VERSION="" - CUDNN_INSTALL_PATH="" -done - -# Set default CUDNN version if not set -if [ -z "$TF_CUDNN_VERSION" ]; then - TF_CUDNN_VERSION="6" - export TF_CUDNN_VERSION -fi -write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION" - -# Configure the compute capabilities that TensorFlow builds for. -# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work. -while true; do - fromuser="" - default_cuda_compute_capabilities="3.5,5.2" - if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then -cat << EOF -Please specify a list of comma-separated Cuda compute capabilities you want to build with. -You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. -Please note that each additional compute capability significantly increases your build time and binary size. 
-EOF - read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES - fromuser=1 - fi - if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then - TF_CUDA_COMPUTE_CAPABILITIES=$default_cuda_compute_capabilities - fi - # Check whether all capabilities from the input is valid - COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ } - ALL_VALID=1 - for CAPABILITY in $COMPUTE_CAPABILITIES; do - if [[ ! "$CAPABILITY" =~ [0-9]+.[0-9]+ ]]; then - echo "Invalid compute capability: " $CAPABILITY - ALL_VALID=0 - break - fi - done - if [ "$ALL_VALID" == "0" ]; then - if [ -z "$fromuser" ]; then - exit 1 - fi - else - export TF_CUDA_COMPUTE_CAPABILITIES - write_action_env_to_bazelrc "TF_CUDA_COMPUTE_CAPABILITIES" "$TF_CUDA_COMPUTE_CAPABILITIES" - break - fi - TF_CUDA_COMPUTE_CAPABILITIES="" -done - -if is_windows; then - # The following three variables are needed for MSVC toolchain configuration in Bazel - export CUDA_PATH="$CUDA_TOOLKIT_PATH" - export CUDA_COMPUTE_CAPABILITIES="$TF_CUDA_COMPUTE_CAPABILITIES" - export NO_WHOLE_ARCHIVE_OPTION=1 - write_action_env_to_bazelrc "CUDA_PATH" "$CUDA_PATH" - write_action_env_to_bazelrc "CUDA_COMPUTE_CAPABILITIES" "$CUDA_COMPUTE_CAPABILITIES" - write_action_env_to_bazelrc "NO_WHOLE_ARCHIVE_OPTION" "1" - write_to_bazelrc "build --config=win-cuda" - write_to_bazelrc "test --config=win-cuda" -else - # If CUDA is enabled, always use GPU during build and test. - write_to_bazelrc "build --config=cuda" - write_to_bazelrc "test --config=cuda" -fi - -# end of if "$TF_NEED_CUDA" == "1" -fi - -# OpenCL configuration - -if [ "$TF_NEED_OPENCL" == "1" ]; then - -# Determine which C++ compiler should be used as the host compiler -while true; do - fromuser="" - if [ -z "$HOST_CXX_COMPILER" ]; then - default_cxx_host_compiler=$(which g++ || true) - read -p "Please specify which C++ compiler should be used as the host C++ compiler. 
[Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER - fromuser="1" - if [ -z "$HOST_CXX_COMPILER" ]; then - HOST_CXX_COMPILER=$default_cxx_host_compiler - fi - fi - if [ -e "$HOST_CXX_COMPILER" ]; then - export HOST_CXX_COMPILER - write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER" - break - fi - echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - HOST_CXX_COMPILER="" - # Retry -done - -# Determine which C compiler should be used as the host compiler -while true; do - fromuser="" - if [ -z "$HOST_C_COMPILER" ]; then - default_c_host_compiler=$(which gcc || true) - read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER - fromuser="1" - if [ -z "$HOST_C_COMPILER" ]; then - HOST_C_COMPILER=$default_c_host_compiler - fi - fi - if [ -e "$HOST_C_COMPILER" ]; then - export HOST_C_COMPILER - write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER" - break - fi - echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - HOST_C_COMPILER="" - # Retry -done - -while true; do - # Configure the OPENCL version to use. - TF_OPENCL_VERSION="1.2" - - # Point to ComputeCpp root - if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then - default_computecpp_toolkit_path=/usr/local/computecpp - read -p "Please specify the location where ComputeCpp for SYCL $TF_OPENCL_VERSION is installed. 
[Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH - fromuser="1" - if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then - COMPUTECPP_TOOLKIT_PATH=$default_computecpp_toolkit_path - fi - fi - - if is_linux; then - SYCL_RT_LIB_PATH="lib/libComputeCpp.so" - fi - - if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then - export COMPUTECPP_TOOLKIT_PATH - write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH" - break - fi - echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found" - - if [ -z "$fromuser" ]; then - exit 1 - fi - # Retry - TF_OPENCL_VERSION="" - COMPUTECPP_TOOLKIT_PATH="" -done - -# end of if "$TF_NEED_OPENCL" == "1" -fi - - -while [ "$TF_NEED_MPI" == "" ]; do - read -p "Do you wish to build TensorFlow with "\ -"MPI support? [y/N] " INPUT - case $INPUT in - [Yy]* ) echo "MPI support will be enabled for "\ -"TensorFlow"; TF_NEED_MPI=1;; - [Nn]* ) echo "MPI support will not be enabled for "\ -"TensorFlow"; TF_NEED_MPI=0;; - "" ) echo "MPI support will not be enabled for "\ -"TensorFlow"; TF_NEED_MPI=0;; - * ) echo "Invalid selection: " $INPUT;; - esac -done - -# Find out where the MPI toolkit is installed -while true; do - if [ "$TF_NEED_MPI" == "0" ]; then - break; - fi - - fromuser="" - if [ -z "$MPI_HOME" ]; then - #Get the base folder by removing the bin path - default_mpi_path=$(dirname $(dirname $(which mpirun)) || dirname $(dirname $(which mpiexec)) || true) - read -p "Please specify the MPI toolkit folder. [Default is $default_mpi_path]: " MPI_HOME - fromuser="1" - if [ -z "$MPI_HOME" ]; then - MPI_HOME=$default_mpi_path - fi - fi - - #Check that the include and library folders are where we expect them to be - if [ -e "$MPI_HOME/include" ] && [ -e "$MPI_HOME/lib" ]; then - break - fi - - echo "Invalid path to the MPI Toolkit. ${MPI_HOME}/include or ${MPI_HOME}/lib cannot be found." 
- if [ -z "$fromuser" ]; then - exit 1 - fi - - # Retry - MPI_HOME="" -done - - -if [ "$TF_NEED_MPI" == "1" ]; then - write_to_bazelrc 'build --define with_mpi_support=true' - - #Link the MPI header files - ln -sf "${MPI_HOME}/include/mpi.h" third_party/mpi/mpi.h - - - #Determine if we use OpenMPI or MVAPICH, these require different header files - #to be included here to make bazel dependency checker happy - - if [ -e "${MPI_HOME}/include/mpi_portable_platform.h" ]; then - #OpenMPI - ln -sf "${MPI_HOME}/include/mpi_portable_platform.h" third_party/mpi/ - sed -i -e "s/MPI_LIB_IS_OPENMPI=False/MPI_LIB_IS_OPENMPI=True/" third_party/mpi/mpi.bzl - else - #MVAPICH / MPICH - ln -sf "${MPI_HOME}/include/mpio.h" third_party/mpi/ - ln -sf "${MPI_HOME}/include/mpicxx.h" third_party/mpi/ - sed -i -e "s/MPI_LIB_IS_OPENMPI=True/MPI_LIB_IS_OPENMPI=False/" third_party/mpi/mpi.bzl - fi - - - if [ -e "${MPI_HOME}/lib/libmpi.so" ]; then - ln -sf "${MPI_HOME}/lib/libmpi.so" third_party/mpi/ - else - echo "Cannot find the MPI library file in ${MPI_HOME}/lib " - exit 1 - fi +if [ -z "$PYTHON_BIN_PATH" ]; then + PYTHON_BIN_PATH=$(which python || which python3 || true) fi +# Set all env variables +"$PYTHON_BIN_PATH" configure.py echo "Configuration finished" + diff --git a/configure.py b/configure.py new file mode 100644 index 0000000000000000000000000000000000000000..edb0a47ee611d3f355a1dd1622847d738bfe0122 --- /dev/null +++ b/configure.py @@ -0,0 +1,950 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""configure script to get build parameters from user.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import errno +import os +import platform +import re +import site +import subprocess +import sys + +_TF_BAZELRC = '.tf_configure.bazelrc' +_DEFAULT_CUDA_VERSION = '8.0' +_DEFAULT_CUDNN_VERSION = '6' +_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' +_DEFAULT_CUDA_PATH = '/usr/local/cuda' +_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' +_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' + 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) +_TF_OPENCL_VERSION = '1.2' +_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' + + +def is_windows(): + return platform.system() == 'Windows' + + +def is_linux(): + return platform.system() == 'Linux' + + +def is_macos(): + return platform.system() == 'Darwin' + + +def is_ppc64le(): + return platform.machine() == 'ppc64le' + + +def get_input(question): + try: + try: + answer = raw_input(question) + except NameError: + answer = input(question) # pylint: disable=bad-builtin + except EOFError: + answer = '' + return answer + + +def symlink_force(target, link_name): + """Force symlink, equivalent of 'ln -sf'. + + Args: + target: items to link to. + link_name: name of the link. + """ + try: + os.symlink(target, link_name) + except OSError as e: + if e.errno == errno.EEXIST: + os.remove(link_name) + os.symlink(target, link_name) + else: + raise e + + +def sed_in_place(filename, old, new): + """Replace old string with new string in file. + + Args: + filename: string for filename. + old: string to replace. + new: new string to replace to. 
+ """ + with open(filename, 'r') as f: + filedata = f.read() + newdata = filedata.replace(old, new) + with open(filename, 'w') as f: + f.write(newdata) + + +def remove_line_with(filename, token): + """Remove lines that contain token from file. + + Args: + filename: string for filename. + token: string token to check if to remove a line from file or not. + """ + with open(filename, 'r') as f: + filedata = f.read() + + with open(filename, 'w') as f: + for line in filedata.strip().split('\n'): + if token not in line: + f.write(line + '\n') + + +def write_to_bazelrc(line): + with open(_TF_BAZELRC, 'a') as f: + f.write(line + '\n') + + +def write_action_env_to_bazelrc(var_name, var): + write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) + + +def run_shell(cmd): + return subprocess.check_output(cmd, shell=True).decode('UTF-8').strip() + + +def cygpath(path): + """Convert path from posix to windows.""" + return run_shell('cygpath -m "%s"' % path) + + +def get_python_path(environ_cp): + """Get the python site package paths.""" + python_paths = [] + if environ_cp.get('PYTHONPATH'): + python_paths = environ_cp.get('PYTHONPATH').split(':') + try: + library_paths = site.getsitepackages() + except AttributeError: + from distutils.sysconfig import get_python_lib # pylint: disable=g-import-not-at-top + library_paths = [get_python_lib()] + all_paths = set(python_paths + library_paths) + + paths = [] + for path in all_paths: + if os.path.isdir(path): + paths.append(path) + return paths + + +def setup_python(environ_cp): + """Setup python related env variables.""" + # Get PYTHON_BIN_PATH, default is the current running python. + default_python_bin_path = sys.executable + ask_python_bin_path = ('Please specify the location of python. 
[Default is ' + '%s]: ') % default_python_bin_path + while True: + python_bin_path = get_from_env_or_user_or_default( + environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path, + default_python_bin_path) + # Check if the path is valid + if (os.path.isfile(python_bin_path) and os.access( + python_bin_path, os.X_OK)) or (os.path.isdir(python_bin_path)): + break + elif not os.path.exists(python_bin_path): + print('Invalid python path: %s cannot be found.' % python_bin_path) + else: + print('%s is not executable. Is it the python binary?' % python_bin_path) + environ_cp['PYTHON_BIN_PATH'] = '' + + # Get PYTHON_LIB_PATH + python_lib_path = environ_cp.get('PYTHON_LIB_PATH') + if not python_lib_path: + python_lib_paths = get_python_path(environ_cp) + if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1': + python_lib_path = python_lib_paths[0] + else: + print('Found possible Python library paths:\n%s' % + '\n'.join(python_lib_paths)) + default_python_lib_path = python_lib_paths[0] + python_lib_path = get_input( + 'Please input the desired Python library path to use. 
Default is %s' + % python_lib_paths[0]) + if not python_lib_path: + python_lib_path = default_python_lib_path + environ_cp['PYTHON_LIB_PATH'] = python_lib_path + + python_major_version = sys.version_info[0] + # Convert python path to Windows style before writing into bazel.rc + if is_windows(): + python_bin_path = cygpath(python_bin_path) + python_lib_path = cygpath(python_lib_path) + + # Set-up env variables used by python_configure.bzl + write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) + write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) + write_to_bazelrc('build --define PYTHON_BIN_PATH="%s"' % python_bin_path) + write_to_bazelrc('build --define PYTHON_LIB_PATH="%s"' % python_lib_path) + write_to_bazelrc('build --force_python=py%s' % python_major_version) + write_to_bazelrc('build --host_force_python=py%s' % python_major_version) + write_to_bazelrc('build --python%s_path=\"%s"' % (python_major_version, + python_bin_path)) + write_to_bazelrc('test --force_python=py%s' % python_major_version) + write_to_bazelrc('test --host_force_python=py%s' % python_major_version) + write_to_bazelrc('test --define PYTHON_BIN_PATH="%s"' % python_bin_path) + write_to_bazelrc('test --define PYTHON_LIB_PATH="%s"' % python_lib_path) + write_to_bazelrc('run --define PYTHON_BIN_PATH="%s"' % python_bin_path) + write_to_bazelrc('run --define PYTHON_LIB_PATH="%s"' % python_lib_path) + environ_cp['PYTHON_BIN_PATH'] = python_bin_path + + # Write tools/python_bin_path.sh + with open('tools/python_bin_path.sh', 'w') as f: + f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) + + +def reset_tf_configure_bazelrc(): + """Reset file that contains customized config settings.""" + open(_TF_BAZELRC, 'w').close() + + home = os.path.expanduser('~') + if not os.path.exists('.bazelrc'): + if os.path.exists(os.path.join(home, '.bazelrc')): + with open('.bazelrc', 'a') as f: + f.write('import %s/.bazelrc\n' % home) + else: + open('.bazelrc', 'w').close() + + 
remove_line_with('.bazelrc', 'tf_configure') + with open('.bazelrc', 'a') as f: + f.write('import %workspace%/.tf_configure.bazelrc\n') + + +def run_gen_git_source(environ_cp): + """Run the gen_git_source to create links. + + The links are for bazel to track dependencies for git hash propagation. + + Args: + environ_cp: copy of the os.environ. + """ + cmd = '"%s" tensorflow/tools/git/gen_git_source.py --configure %s' % ( + environ_cp.get('PYTHON_BIN_PATH'), os.getcwd()) + os.system(cmd) + + +def cleanup_makefile(): + """Delete any leftover BUILD files from the Makefile build. + + These files could interfere with Bazel parsing. + """ + makefile_download_dir = 'tensorflow/contrib/makefile/downloads' + if os.path.isdir(makefile_download_dir): + for root, _, filenames in os.walk(makefile_download_dir): + for f in filenames: + if f.endswith('BUILD'): + os.remove(os.path.join(root, f)) + + +def get_var(environ_cp, + var_name, + query_item, + enabled_by_default, + question=None, + yes_reply=None, + no_reply=None): + """Get boolean input from user. + + If var_name is not set in env, ask user to enable query_item or not. If the + response is empty, use the default. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". + query_item: string for feature related to the variable, e.g. "Hadoop File + System". + enabled_by_default: boolean for default behavior. + question: optional string for how to ask for user input. + yes_reply: optionanl string for reply when feature is enabled. + no_reply: optional string for reply when feature is disabled. + + Returns: + boolean value of the variable. + """ + if not question: + question = 'Do you wish to build TensorFlow with %s support?' % query_item + if not yes_reply: + yes_reply = '%s support will be enabled for TensorFlow.' 
% query_item + if not no_reply: + no_reply = 'No %s' % yes_reply + + yes_reply += '\n' + no_reply += '\n' + + if enabled_by_default: + question += ' [Y/n]: ' + else: + question += ' [y/N]: ' + + var = environ_cp.get(var_name) + while var is None: + user_input_origin = get_input(question) + user_input = user_input_origin.strip().lower() + if user_input == 'y': + print(yes_reply) + var = True + elif user_input == 'n': + print(no_reply) + var = False + elif not user_input: + if enabled_by_default: + print(yes_reply) + var = True + else: + print(no_reply) + var = False + else: + print('Invalid selection: %s' % user_input_origin) + return var + + +def set_build_var(environ_cp, var_name, query_item, option_name, + enabled_by_default): + """Set if query_item will be enabled for the build. + + Ask user if query_item will be enabled. Default is used if no input is given. + Set subprocess environment variable and write to .bazelrc if enabled. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". + query_item: string for feature related to the variable, e.g. "Hadoop File + System". + option_name: string for option to define in .bazelrc. + enabled_by_default: boolean for default behavior. + """ + + var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) + environ_cp[var_name] = var + if var == '1': + write_to_bazelrc('build --define %s=true' % option_name) + + +def set_action_env_var(environ_cp, + var_name, + query_item, + enabled_by_default, + question=None, + yes_reply=None, + no_reply=None): + """Set boolean action_env variable. + + Ask user if query_item will be enabled. Default is used if no input is given. + Set environment variable and write to .bazelrc. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". + query_item: string for feature related to the variable, e.g. "Hadoop File + System". 
+ enabled_by_default: boolean for default behavior. + question: optional string for how to ask for user input. + yes_reply: optionanl string for reply when feature is enabled. + no_reply: optional string for reply when feature is disabled. + """ + var = int( + get_var(environ_cp, var_name, query_item, enabled_by_default, question, + yes_reply, no_reply)) + + write_action_env_to_bazelrc(var_name, var) + environ_cp[var_name] = str(var) + + +def check_bazel_version(min_version): + """Check installed bezel version is at least min_version. + + Args: + min_version: string for minimum bazel version. + """ + try: + curr_version = run_shell('bazel --batch version') + except subprocess.CalledProcessError: + print('Cannot find bazel. Please install bazel.') + sys.exit(0) + + for line in curr_version.split('\n'): + if 'Build label: ' in line: + curr_version = line.split('Build label: ')[1] + break + + min_version_segments = min_version.split('.') + curr_version_segments = curr_version.split('.') + + # Check if current bazel version can be detected properly. + for seg in curr_version_segments: + if not seg.isdigit(): + print('WARNING: current bazel installation is not a release version.') + print('Make sure you are running at least bazel %s' % min_version) + return + + min_version_str = ''.join(['%03d' % int(seg) for seg in min_version_segments]) + curr_version_str = ''.join( + ['%03d' % int(seg) for seg in curr_version_segments]) + if int(curr_version_str) < int(min_version_str): + print('Please upgrade your bazel installation to version %s or higher to ' + 'build TensorFlow!' % min_version) + sys.exit(0) + + +def set_cc_opt_flags(environ_cp): + """Set up architecture-dependent optimization flags. + + Also append CC optimization flags to bazel.rc.. + + Args: + environ_cp: copy of the os.environ. 
+ """ + if is_ppc64le(): + # gcc on ppc64le does not support -march, use mcpu instead + default_cc_opt_flags = '-mcpu=native' + else: + default_cc_opt_flags = '-march=native' + question = ('Please specify optimization flags to use during compilation when' + ' bazel option "--config=opt" is specified [Default is %s]: ' + ) % default_cc_opt_flags + cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', + question, default_cc_opt_flags) + for opt in cc_opt_flags.split(): + write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt)) + + +def set_tf_cuda_clang(environ_cp): + """set TF_CUDA_CLANG action_env. + + Args: + environ_cp: copy of the os.environ. + """ + question = 'Do you want to use clang as CUDA compiler?' + yes_reply = 'Clang will be used as CUDA compiler.' + no_reply = 'nvcc will be used as CUDA compiler.' + set_action_env_var( + environ_cp, + 'TF_CUDA_CLANG', + None, + False, + question=question, + yes_reply=yes_reply, + no_reply=no_reply) + + +def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, + var_default): + """Get var_name either from env, or user or default. + + If var_name has been set as environment variable, use the preset value, else + ask for user input. If no input is provided, the default is used. + + Args: + environ_cp: copy of the os.environ. + var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". + ask_for_var: string for how to ask for user input. + var_default: default value string. + + Returns: + string value for var_name + """ + var = environ_cp.get(var_name) + if not var: + var = get_input(ask_for_var) + if not var: + var = var_default + return var + + +def set_clang_cuda_compiler_path(environ_cp): + """Set CLANG_CUDA_COMPILER_PATH.""" + default_clang_path = run_shell('which clang || true') + ask_clang_path = ('Please specify which clang should be used as device and ' + 'host compiler. 
[Default is %s]: ') % default_clang_path + + while True: + clang_cuda_compiler_path = get_from_env_or_user_or_default( + environ_cp, 'CLANG_CUDA_COMPILER_PATH', ask_clang_path, + default_clang_path) + if os.path.exists(clang_cuda_compiler_path): + break + + # Reset and retry + print('Invalid clang path: %s cannot be found.' % clang_cuda_compiler_path) + environ_cp['CLANG_CUDA_COMPILER_PATH'] = '' + + # Set CLANG_CUDA_COMPILER_PATH + environ_cp['CLANG_CUDA_COMPILER_PATH'] = clang_cuda_compiler_path + write_action_env_to_bazelrc('CLANG_CUDA_COMPILER_PATH', + clang_cuda_compiler_path) + + +def set_gcc_host_compiler_path(environ_cp): + """Set GCC_HOST_COMPILER_PATH.""" + default_gcc_host_compiler_path = run_shell('which gcc || true') + cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH') + + if os.path.islink(cuda_bin_symlink): + # os.readlink is only available in linux + default_gcc_host_compiler_path = run_shell('readlink %s' % cuda_bin_symlink) + + ask_gcc_path = ( + 'Please specify which gcc should be used by nvcc as the ' + 'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path + while True: + gcc_host_compiler_path = get_from_env_or_user_or_default( + environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path, + default_gcc_host_compiler_path) + + if os.path.exists(gcc_host_compiler_path): + break + + # Reset and retry + print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path) + environ_cp['GCC_HOST_COMPILER_PATH'] = '' + + # Set GCC_HOST_COMPILER_PATH + environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path + write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) + + +def set_tf_cuda_version(environ_cp): + """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" + ask_cuda_version = ( + 'Please specify the CUDA SDK version you want to use, ' + 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION + + while True: + # Configure the Cuda SDK version to use. 
+ tf_cuda_version = get_from_env_or_user_or_default( + environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) + + # Find out where the CUDA toolkit is installed + default_cuda_path = _DEFAULT_CUDA_PATH + if is_windows(): + default_cuda_path = cygpath( + environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN)) + elif is_linux(): + # If the default doesn't exist, try an alternative default. + if (not os.path.exists(default_cuda_path) + ) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX): + default_cuda_path = _DEFAULT_CUDA_PATH_LINUX + ask_cuda_path = ('Please specify the location where CUDA %s toolkit is' + ' installed. Refer to README.md for more details. ' + '[Default is %s]: ') % (tf_cuda_version, default_cuda_path) + cuda_toolkit_path = get_from_env_or_user_or_default( + environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path) + + if is_windows(): + cuda_rt_lib_path = 'lib/x64/cudart.lib' + elif is_linux(): + cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version + elif is_macos(): + cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version + + cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path) + if os.path.exists(cuda_toolkit_path_full): + break + + # Reset and retry + print('Invalid path to CUDA %s toolkit. %s cannot be found' % + (tf_cuda_version, cuda_toolkit_path_full)) + environ_cp['TF_CUDA_VERSION'] = '' + environ_cp['CUDA_TOOLKIT_PATH'] = '' + + # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION + environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path + write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) + environ_cp['TF_CUDA_VERSION'] = tf_cuda_version + write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version) + + +def set_tf_cunn_version(environ_cp): + """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" + ask_cudnn_version = ( + '"Please specify the cuDNN version you want to use. 
' + '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION + + while True: + tf_cudnn_version = get_from_env_or_user_or_default( + environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, + _DEFAULT_CUDNN_VERSION) + + default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' + 'installed. Refer to README.md for more details. [Default' + ' is %s]:') % (tf_cudnn_version, default_cudnn_path) + cudnn_install_path = get_from_env_or_user_or_default( + environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + cudnn_install_path = os.path.realpath( + os.path.expanduser(cudnn_install_path)) + if is_windows(): + cudnn_install_path = cygpath(cudnn_install_path) + + if is_windows(): + cuda_dnn_lib_path = 'lib/x64/cudnn.lib' + cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib' + elif is_linux(): + cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version + cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version + elif is_macos(): + cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version + cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version + + cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path) + cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path, + cuda_dnn_lib_alt_path) + if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists( + cuda_dnn_lib_alt_path_full): + break + + # Try another alternative for Linux + if is_linux(): + if subprocess.call(['which', 'ldconfig']): + ldconfig_bin = '/sbin/ldconfig' + else: + ldconfig_bin = 'ldconfig' + cudnn_path_from_ldconfig = run_shell( + r'%s -p | sed -n "s/.*libcudnn.so .* => \(.*\)/\\1/p"' % ldconfig_bin) + if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)): + cudnn_install_path = 
os.path.dirname(cudnn_path_from_ldconfig) + break + + # Reset and Retry + print( + 'Invalid path to cuDNN %s toolkit. None of the following files can be ' + 'found:' % tf_cudnn_version) + print(cuda_dnn_lib_path_full) + print(cuda_dnn_lib_alt_path_full) + if is_linux(): + print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)) + + environ_cp['TF_CUDNN_VERSION'] = '' + + # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION + environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path + write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path) + environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version + write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) + + +def get_native_cuda_compute_capabilities(environ_cp): + """Get native cuda compute capabilities. + + Args: + environ_cp: copy of the os.environ. + Returns: + string of native cuda compute capabilities, separated by comma. + """ + device_query_bin = os.path.join( + environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery') + cmd = (r'"%s" | grep "Capability" | grep -o "[0-9]*\.[0-9]*" | sed ' + '":a;{N;s/\\n/,/};ba"') % device_query_bin + try: + output = run_shell(cmd) + except subprocess.CalledProcessError: + output = '' + return output + + +def set_tf_cuda_compute_capabilities(environ_cp): + """Set TF_CUDA_COMPUTE_CAPABILITIES.""" + while True: + native_cuda_compute_capabilities = get_native_cuda_compute_capabilities( + environ_cp) + if not native_cuda_compute_capabilities: + default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES + else: + default_cuda_compute_capabilities = native_cuda_compute_capabilities + + ask_cuda_compute_capabilities = ( + 'Please specify a list of comma-separated ' + 'Cuda compute capabilities you want to ' + 'build with.\nYou can find the compute ' + 'capability of your device at: ' + 'https://developer.nvidia.com/cuda-gpus.\nPlease' + ' note that each additional compute ' + 'capability significantly increases your ' + 'build time and binary size. 
[Default is: %s]' % + default_cuda_compute_capabilities) + tf_cuda_compute_capabilities = get_from_env_or_user_or_default( + environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', + ask_cuda_compute_capabilities, default_cuda_compute_capabilities) + # Check whether all capabilities from the input is valid + all_valid = True + for compute_capability in tf_cuda_compute_capabilities.split(','): + if not re.match('[0-9]+.[0-9]+', compute_capability): + print('Invalid compute capability: ' % compute_capability) + all_valid = False + + if all_valid: + break + + # Reset and Retry + environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = '' + + # Set TF_CUDA_COMPUTE_CAPABILITIES + environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities + write_action_env_to_bazelrc('TF_CUDA_COMPUTE_CAPABILITIES', + tf_cuda_compute_capabilities) + + +def set_other_cuda_vars(environ_cp): + """Set other CUDA related variables.""" + if is_windows(): + # The following three variables are needed for MSVC toolchain configuration + # in Bazel + environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH') + environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get( + 'TF_CUDA_COMPUTE_CAPABILITIES') + environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1 + write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH')) + write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE', + environ_cp.get('CUDA_COMPUTE_CAPABILITIE')) + write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION', + environ_cp.get('NO_WHOLE_ARCHIVE_OPTION')) + write_to_bazelrc('build --config=win-cuda') + write_to_bazelrc('test --config=win-cuda') + else: + # If CUDA is enabled, always use GPU during build and test. 
+ if environ_cp.get('TF_CUDA_CLANG') == '1': + write_to_bazelrc('build --config=cuda_clang') + write_to_bazelrc('test --config=cuda_clang') + else: + write_to_bazelrc('build --config=cuda') + write_to_bazelrc('test --config=cuda') + + +def set_host_cxx_compiler(environ_cp): + """Set HOST_CXX_COMPILER.""" + default_cxx_host_compiler = run_shell('which g++ || true') + ask_cxx_host_compiler = ( + 'Please specify which C++ compiler should be used as' + ' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler + + while True: + host_cxx_compiler = get_from_env_or_user_or_default( + environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler, + default_cxx_host_compiler) + if os.path.exists(host_cxx_compiler): + break + + # Reset and retry + print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler) + environ_cp['HOST_CXX_COMPILER'] = '' + + # Set HOST_CXX_COMPILER + environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler + write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) + + +def set_host_c_compiler(environ_cp): + """Set HOST_C_COMPILER.""" + default_c_host_compiler = run_shell('which gcc || true') + ask_c_host_compiler = ( + 'Please specify which C compiler should be used as the' + ' host C compiler. [Default is %s]: ') % default_c_host_compiler + + while True: + host_c_compiler = get_from_env_or_user_or_default( + environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler, + default_c_host_compiler) + if os.path.exists(host_c_compiler): + break + + # Reset and retry + print('Invalid C compiler path. %s cannot be found' % host_c_compiler) + environ_cp['HOST_C_COMPILER'] = '' + + # Set HOST_C_COMPILER + environ_cp['HOST_C_COMPILER'] = host_c_compiler + write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) + + +def set_computecpp_toolkit_path(environ_cp): + """Set COMPUTECPP_TOOLKIT_PATH.""" + ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp ' + 'for SYCL %s is installed. 
[Default is %s]: ' + ) % (_TF_OPENCL_VERSION, + _DEFAULT_COMPUTECPP_TOOLKIT_PATH) + + while True: + computecpp_toolkit_path = get_from_env_or_user_or_default( + environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path, + _DEFAULT_COMPUTECPP_TOOLKIT_PATH) + if is_linux(): + sycl_rt_lib_path = 'lib/libComputeCpp.so' + else: + sycl_rt_lib_path = '' + + sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path, + sycl_rt_lib_path) + if os.path.exists(sycl_rt_lib_path_full): + break + + print('Invalid SYCL %s library path. %s cannot be found' % + (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) + environ_cp['COMPUTECPP_TOOLKIT_PATH'] = '' + + # Set COMPUTECPP_TOOLKIT_PATH + environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path + write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', + computecpp_toolkit_path) + + +def set_mpi_home(environ_cp): + """Set MPI_HOME.""" + cmd = ('dirname $(dirname $(which mpirun)) || dirname $(dirname $(which ' + 'mpiexec)) || true') + default_mpi_home = run_shell(cmd) + ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: ' + ) % default_mpi_home + while True: + mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME', + ask_mpi_home, default_mpi_home) + + if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists( + os.path.join(mpi_home, 'lib')): + break + + print('Invalid path to the MPI Toolkit. 
%s or %s cannot be found' % + (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')))) + environ_cp['MPI_HOME'] = '' + + # Set MPI_HOME + environ_cp['MPI_HOME'] = str(mpi_home) + + +def set_other_mpi_vars(environ_cp): + """Set other MPI related variables.""" + # Link the MPI header files + mpi_home = environ_cp.get('MPI_HOME') + symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h') + + # Determine if we use OpenMPI or MVAPICH, these require different header files + # to be included here to make bazel dependency checker happy + if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')): + symlink_force( + os.path.join(mpi_home, 'include/mpi_portable_platform.h'), + 'third_party/mpi/mpi_portable_platform.h') + # TODO(gunan): avoid editing files in configure + sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=False', + 'MPI_LIB_IS_OPENMPI=True') + else: + # MVAPICH / MPICH + symlink_force( + os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h') + symlink_force( + os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h') + # TODO(gunan): avoid editing files in configure + sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=True', + 'MPI_LIB_IS_OPENMPI=False') + + if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so') + else: + raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) + + +def set_mkl(): + write_to_bazelrc('build:mkl --define with_mkl_support=true') + write_to_bazelrc('build:mkl --define using_mkl=true') + write_to_bazelrc('build:mkl -c opt') + write_to_bazelrc('build:mkl --copt="-DEIGEN_USE_VML"') + print( + 'Add "--config=mkl" to your bazel command to build with MKL ' + 'support.\nPlease note that MKL on MacOS or windows is still not ' + 'supported.\nIf you would like to use a local MKL instead of ' + 'downloading, please set the 
environment variable \"TF_MKL_ROOT\" every ' + 'time before build.') + + +def main(): + # Make a copy of os.environ to be clear when functions and getting and setting + # environment variables. + environ_cp = dict(os.environ) + + check_bazel_version('0.4.5') + + reset_tf_configure_bazelrc() + cleanup_makefile() + setup_python(environ_cp) + run_gen_git_source(environ_cp) + + if is_windows(): + environ_cp['TF_NEED_GCP'] = '0' + environ_cp['TF_NEED_HDFS'] = '0' + environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_OPENCL'] = '0' + environ_cp['TF_CUDA_CLANG'] = '0' + + if is_macos(): + environ_cp['TF_NEED_JEMALLOC'] = '0' + + set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', + 'with_jemalloc', True) + set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', + 'with_gcp_support', False) + set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', + 'with_hdfs_support', False) + set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', + False) + set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', + False) + + set_action_env_var(environ_cp, 'TF_NEED_OPENCL', 'OpenCL', False) + if environ_cp.get('TF_NEED_OPENCL') == '1': + set_host_cxx_compiler(environ_cp) + set_host_c_compiler(environ_cp) + set_computecpp_toolkit_path(environ_cp) + + set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) + if environ_cp.get('TF_NEED_CUDA') == '1': + set_tf_cuda_version(environ_cp) + set_tf_cunn_version(environ_cp) + set_tf_cuda_compute_capabilities(environ_cp) + + set_tf_cuda_clang(environ_cp) + if environ_cp.get('TF_CUDA_CLANG') == '1': + # Set up which clang we should use as the cuda / host compiler. 
+ set_clang_cuda_compiler_path(environ_cp) + else: + # Set up which gcc nvcc should use as the host compiler + # No need to set this on Windows + if not is_windows(): + set_gcc_host_compiler_path(environ_cp) + set_other_cuda_vars(environ_cp) + + set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) + if environ_cp.get('TF_NEED_MPI') == '1': + set_mpi_home(environ_cp) + set_other_mpi_vars(environ_cp) + + set_cc_opt_flags(environ_cp) + set_mkl() + + +if __name__ == '__main__': + main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6450b2ad878b57191ae3b12e7e39213ac168eef6..a162bcf452515315c621c6c72740413bf21f849d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -39,7 +39,7 @@ config_setting( config_setting( name = "android_armeabi", values = { - "cc_target_os": "android", + "crosstool_top": "//external:android/crosstool", "cpu": "armeabi", }, visibility = ["//visibility:public"], @@ -63,6 +63,24 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "android_mips", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "mips", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_mips64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "mips64", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "darwin", values = {"cpu": "darwin"}, @@ -178,7 +196,10 @@ config_setting( package_group( name = "internal", - packages = ["//tensorflow/..."], + packages = [ + "//learning/protonn/llgtm/...", + "//tensorflow/...", + ], ) filegroup( @@ -216,9 +237,12 @@ filegroup( "//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files", "//tensorflow/compiler/jit/ops:all_files", + "//tensorflow/compiler/plugin/executor:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", + "//tensorflow/compiler/tf2xla/cc:all_files", 
"//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/tf2xla/ops:all_files", "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", @@ -238,6 +262,7 @@ filegroup( "//tensorflow/contrib/batching/util:all_files", "//tensorflow/contrib/bayesflow:all_files", "//tensorflow/contrib/boosted_trees:all_files", + "//tensorflow/contrib/boosted_trees/estimator_batch:all_files", "//tensorflow/contrib/boosted_trees/lib:all_files", "//tensorflow/contrib/boosted_trees/proto:all_files", "//tensorflow/contrib/boosted_trees/resources:all_files", @@ -253,13 +278,14 @@ filegroup( "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/data/python/util:all_files", - "//tensorflow/contrib/decision_trees:all_files", + "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", "//tensorflow/contrib/ffmpeg:all_files", "//tensorflow/contrib/ffmpeg/default:all_files", "//tensorflow/contrib/framework:all_files", + "//tensorflow/contrib/fused_conv:all_files", "//tensorflow/contrib/graph_editor:all_files", "//tensorflow/contrib/grid_rnn:all_files", "//tensorflow/contrib/hooks:all_files", @@ -284,6 +310,9 @@ filegroup( "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", + "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/remote_fused_graph/pylib:all_files", + "//tensorflow/contrib/resampler:all_files", "//tensorflow/contrib/rnn:all_files", "//tensorflow/contrib/saved_model:all_files", "//tensorflow/contrib/saved_model/cc/saved_model:all_files", @@ -302,10 +331,17 @@ filegroup( "//tensorflow/contrib/stateless:all_files", "//tensorflow/contrib/tensor_forest:all_files", 
"//tensorflow/contrib/tensor_forest/hybrid:all_files", + "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", + "//tensorflow/contrib/tensor_forest/proto:all_files", "//tensorflow/contrib/tensorboard:all_files", "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", - "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files", + "//tensorflow/contrib/tfprof:all_files", + "//tensorflow/contrib/timeseries:all_files", + "//tensorflow/contrib/timeseries/examples:all_files", + "//tensorflow/contrib/timeseries/python/timeseries:all_files", + "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", + "//tensorflow/contrib/tpu:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", @@ -321,12 +357,16 @@ filegroup( "//tensorflow/core/grappler/optimizers:all_files", "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", + "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", "//tensorflow/core/platform/hadoop:all_files", + "//tensorflow/core/profiler:all_files", + "//tensorflow/core/profiler/internal:all_files", + "//tensorflow/core/profiler/internal/advisor:all_files", "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", "//tensorflow/examples/android:all_files", @@ -351,72 +391,10 @@ filegroup( "//tensorflow/python/kernel_tests:all_files", "//tensorflow/python/kernel_tests/distributions:all_files", "//tensorflow/python/ops/distributions:all_files", + "//tensorflow/python/profiler:all_files", + "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", "//tensorflow/python/tools:all_files", - 
"//tensorflow/tensorboard:all_files", - "//tensorflow/tensorboard/backend:all_files", - "//tensorflow/tensorboard/backend/event_processing:all_files", - "//tensorflow/tensorboard/components:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files", - "//tensorflow/tensorboard/components/tf_backend:all_files", - "//tensorflow/tensorboard/components/tf_backend/test:all_files", - "//tensorflow/tensorboard/components/tf_color_scale:all_files", - "//tensorflow/tensorboard/components/tf_color_scale/test:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files", - "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_globals:all_files", - "//tensorflow/tensorboard/components/tf_graph:all_files", - "//tensorflow/tensorboard/components/tf_graph/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_app:all_files", - "//tensorflow/tensorboard/components/tf_graph_app/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_board:all_files", - "//tensorflow/tensorboard/components/tf_graph_board/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_common:all_files", - "//tensorflow/tensorboard/components/tf_graph_controls:all_files", - "//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_info:all_files", - "//tensorflow/tensorboard/components/tf_graph_info/demo:all_files", - 
"//tensorflow/tensorboard/components/tf_graph_loader:all_files", - "//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files", - "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_imports:all_files", - "//tensorflow/tensorboard/components/tf_option_selector:all_files", - "//tensorflow/tensorboard/components/tf_profile_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_runs_selector:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_storage:all_files", - "//tensorflow/tensorboard/components/tf_storage/test:all_files", - "//tensorflow/tensorboard/components/tf_tensorboard:all_files", - "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_trace_viewer:all_files", - "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", - "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", - "//tensorflow/tensorboard/components/vz_line_chart:all_files", - "//tensorflow/tensorboard/components/vz_projector:all_files", - "//tensorflow/tensorboard/components/vz_projector/test:all_files", - "//tensorflow/tensorboard/components/vz_sorting:all_files", - "//tensorflow/tensorboard/components/vz_sorting/test:all_files", - "//tensorflow/tensorboard/demo:all_files", - "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", - "//tensorflow/tensorboard/plugins:all_files", - "//tensorflow/tensorboard/plugins/audio:all_files", - "//tensorflow/tensorboard/plugins/distributions:all_files", - "//tensorflow/tensorboard/plugins/graphs:all_files", - "//tensorflow/tensorboard/plugins/histograms:all_files", - 
"//tensorflow/tensorboard/plugins/images:all_files", - "//tensorflow/tensorboard/plugins/projector:all_files", - "//tensorflow/tensorboard/plugins/scalars:all_files", - "//tensorflow/tensorboard/plugins/text:all_files", - "//tensorflow/tensorboard/scripts:all_files", "//tensorflow/tools/api/golden:all_files", "//tensorflow/tools/api/lib:all_files", "//tensorflow/tools/api/tests:all_files", @@ -427,12 +405,10 @@ filegroup( "//tensorflow/tools/docker/notebooks:all_files", "//tensorflow/tools/docs:all_files", "//tensorflow/tools/git:all_files", + "//tensorflow/tools/mlpbtxt:all_files", "//tensorflow/tools/proto_text:all_files", "//tensorflow/tools/quantization:all_files", "//tensorflow/tools/test:all_files", - "//tensorflow/tools/tfprof:all_files", - "//tensorflow/tools/tfprof/internal:all_files", - "//tensorflow/tools/tfprof/internal/advisor:all_files", "//tensorflow/user_ops:all_files", "//third_party/hadoop:all_files", "//third_party/sycl:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3ab4e8efcdb5b05cf8922edd302e7cbf3a3597f1..507b2fe1f1420661585b573da541e33e39e5adf5 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -62,6 +62,7 @@ tf_cuda_library( "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", ], }), @@ -102,6 +103,19 @@ tf_cuda_library( # ----------------------------------------------------------------------------- # Tests +tf_cuda_library( + name = "c_test_util", + testonly = 1, + srcs = ["c_test_util.cc"], + hdrs = ["c_test_util.h"], + deps = [ + ":c_api", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "c_api_test", size = "small", @@ -119,6 +133,7 @@ tf_cc_test( # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", + ":c_test_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", @@ 
-138,11 +153,38 @@ tf_cc_test( ], ) +tf_cc_test( + name = "while_loop_test", + size = "small", + srcs = ["while_loop_test.cc"], + deps = [ + ":c_api", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_custom_op_library( name = "test_op.so", srcs = ["test_op.cc"], ) +# ----------------------------------------------------------------------------- +# Python API target + +tf_cuda_library( + name = "python_api", + srcs = ["python_api.cc"], + hdrs = ["python_api.h"], + visibility = ["//tensorflow/python:__pkg__"], + deps = [ + ":c_api", + ":c_api_internal", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 77faa475ed47990a4dcee0e1ca0925af0c1643f9..371264ef6c20dbaa8263668eb526d49bb25c50c0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -28,12 +28,15 @@ limitations under the License. #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -53,21 +56,16 @@ limitations under the License. // The implementation below is at the top level instead of the // brain namespace because we are defining 'extern "C"' functions. 
-using tensorflow::error::Code; -using tensorflow::errors::InvalidArgument; -using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; using tensorflow::AllocationDescription; using tensorflow::DataType; using tensorflow::Graph; using tensorflow::GraphDef; -using tensorflow::mutex_lock; using tensorflow::NameRangeMap; using tensorflow::NameRangesForNode; using tensorflow::NewSession; using tensorflow::Node; -using tensorflow::NodeDef; using tensorflow::NodeBuilder; +using tensorflow::NodeDef; using tensorflow::OpDef; using tensorflow::OpRegistry; using tensorflow::PartialTensorShape; @@ -80,6 +78,11 @@ using tensorflow::TensorBuffer; using tensorflow::TensorId; using tensorflow::TensorShape; using tensorflow::TensorShapeProto; +using tensorflow::error::Code; +using tensorflow::errors::InvalidArgument; +using tensorflow::gtl::ArraySlice; +using tensorflow::mutex_lock; +using tensorflow::strings::StrCat; extern "C" { @@ -163,7 +166,7 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, if (out->data != nullptr) { return InvalidArgument("Passing non-empty TF_Buffer is invalid."); } - const auto proto_size = in.ByteSize(); + const auto proto_size = in.ByteSizeLong(); void* buf = tensorflow::port::Malloc(proto_size); in.SerializeToArray(buf, proto_size); out->data = buf; @@ -255,24 +258,27 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst, return sz; } -size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, - size_t* dst_len, TF_Status* status) { +static Status TF_StringDecode_Impl(const char* src, size_t src_len, + const char** dst, size_t* dst_len) { tensorflow::uint64 len64 = 0; const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64); if (p == nullptr) { - status->status = - InvalidArgument("invalid string encoding or truncated src buffer"); - return 0; + return InvalidArgument("invalid string encoding or truncated src buffer"); } if (len64 > std::numeric_limits::max()) { - 
status->status = - InvalidArgument("encoded string is ", len64, - "-bytes, which is too large for this architecture"); - return 0; + return InvalidArgument("encoded string is ", len64, + "-bytes, which is too large for this architecture"); } *dst = p; *dst_len = static_cast(len64); - return static_cast(p - src) + *dst_len; + return Status::OK(); +} + +size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, + size_t* dst_len, TF_Status* status) { + status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); + if (!status->status.ok()) return 0; + return static_cast(*dst - src) + *dst_len; } size_t TF_StringEncodedSize(size_t len) { @@ -388,16 +394,20 @@ void TF_Reset(const TF_SessionOptions* opt, const char** containers, namespace tensorflow { -// Non-static for testing. -bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) { +Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { + if (src->dtype != TF_STRING) { + *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer); + return Status::OK(); + } + // TF_STRING tensors require copying since Tensor class expects a sequence of + // string objects. 
const tensorflow::int64 num_elements = src->shape.num_elements(); const char* input = reinterpret_cast(TF_TensorData(src)); const size_t src_size = TF_TensorByteSize(src); if (static_cast(src_size / sizeof(tensorflow::uint64)) < num_elements) { - status->status = InvalidArgument( + return InvalidArgument( "Malformed TF_STRING tensor; too short to hold number of elements"); - return false; } const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* limit = input + src_size; @@ -408,24 +418,30 @@ bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) { tensorflow::uint64 offset = reinterpret_cast(input)[i]; if (static_cast(offset) >= (limit - data_start)) { - status->status = InvalidArgument("Malformed TF_STRING tensor; element ", - i, " out of range"); - return false; + return InvalidArgument("Malformed TF_STRING tensor; element ", i, + " out of range"); } size_t len; const char* p; const char* srcp = data_start + offset; - TF_StringDecode(srcp, limit - srcp, &p, &len, status); - if (!status->status.ok()) { - return false; - } + Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len); + if (!status.ok()) return status; dstarray(i).assign(p, len); } - return true; + return Status::OK(); } // Non-static for testing. -TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) { +TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) { + if (src.dtype() != DT_STRING) { + TensorBuffer* buf = TensorCApi::Buffer(src); + buf->Ref(); + return new TF_Tensor{static_cast(src.dtype()), src.shape(), + buf}; + } + // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly + // encoded sequence of strings. + // Compute bytes needed for encoding. 
size_t size = 0; const auto& srcarray = src.flat(); @@ -466,15 +482,6 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) { dimvec.size(), base, size, DeleteArray, base); } -class TensorCApi { - public: - static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } - static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, - TensorBuffer* buf) { - return Tensor(static_cast(type), shape, buf); - } -}; - // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to // result in a zero-sized tensor. static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { @@ -513,16 +520,8 @@ static bool TF_Run_Inputs( TF_Status* status) { const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { - TF_Tensor* src = c_inputs[i]; - if (c_inputs[i]->dtype != TF_STRING) { - (*input_pairs)[i].second = tensorflow::TensorCApi::MakeTensor( - src->dtype, src->shape, src->buffer); - } else if (!tensorflow::TF_Tensor_DecodeStrings( - src, &(*input_pairs)[i].second, status)) { - // TF_STRING tensors require copying since Tensor class expects - // a sequence of string objects. - return false; - } + status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); + if (!status->status.ok()) return false; } return true; } @@ -580,15 +579,7 @@ static void TF_Run_Helper( static_cast(src.dtype()), src.shape()); continue; } - if (src.dtype() != tensorflow::DT_STRING) { - // Share the underlying buffer. 
- TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src); - buf->Ref(); - c_outputs[i] = new TF_Tensor{static_cast(src.dtype()), - src.shape(), buf}; - } else { - c_outputs[i] = tensorflow::TF_Tensor_EncodeStrings(src); - } + c_outputs[i] = TF_TensorFromTensor(src); } } @@ -628,7 +619,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, // Target nodes const char** c_target_oper_names, int ntargets, const char** handle, TF_Status* status) { - status->status = Status::OK(); + *handle = nullptr; std::vector input_names(ninputs); std::vector output_names(noutputs); @@ -643,16 +634,12 @@ void TF_PRunSetup(TF_DeprecatedSession* s, target_oper_names[i] = c_target_oper_names[i]; } tensorflow::string new_handle; - Status result; - result = s->session->PRunSetup(input_names, output_names, target_oper_names, - &new_handle); - if (result.ok()) { + status->status = s->session->PRunSetup(input_names, output_names, + target_oper_names, &new_handle); + if (status->status.ok()) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; - } else { - *handle = nullptr; - status->status = result; } } @@ -1072,20 +1059,9 @@ void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, TF_Tensor* value, TF_Status* status) { - status->status = Status::OK(); Tensor t; - bool ok = true; - - if (value->dtype != TF_STRING) { - t = tensorflow::TensorCApi::MakeTensor(value->dtype, value->shape, - value->buffer); - } else { - // TF_STRING tensors require copying since Tensor class expects - // a sequence of string objects. 
- ok = tensorflow::TF_Tensor_DecodeStrings(value, &t, status); - } - - if (ok) desc->node_builder.Attr(attr_name, t); + status->status = TF_TensorToTensor(value, &t); + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, @@ -1094,21 +1070,14 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, status->status = Status::OK(); std::vector t; t.reserve(num_values); - bool ok = true; - for (int i = 0; i < num_values && ok; ++i) { - if (values[i]->dtype != TF_STRING) { - t.emplace_back(tensorflow::TensorCApi::MakeTensor( - values[i]->dtype, values[i]->shape, values[i]->buffer)); - } else { - t.emplace_back(::tensorflow::DT_STRING); - // TF_STRING tensors require copying since Tensor class expects - // a sequence of string objects. - ok = tensorflow::TF_Tensor_DecodeStrings(values[i], &t.back(), status); - } + for (int i = 0; i < num_values && status->status.ok(); ++i) { + Tensor v; + status->status = TF_TensorToTensor(values[i], &v); + t.emplace_back(v); } - if (ok) desc->node_builder.Attr(attr_name, t); + if (status->status.ok()) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, @@ -1565,9 +1534,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, Tensor t; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); if (!status->status.ok()) return; - *value = new TF_Tensor{static_cast(t.dtype()), t.shape(), - tensorflow::TensorCApi::Buffer(t)}; - (*value)->buffer->Ref(); + *value = TF_TensorFromTensor(t); } void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, @@ -1578,10 +1545,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, if (!status->status.ok()) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { - const Tensor& t = ts[i]; 
- values[i] = new TF_Tensor{static_cast(t.dtype()), t.shape(), - tensorflow::TensorCApi::Buffer(t)}; - values[i]->buffer->Ref(); + values[i] = TF_TensorFromTensor(ts[i]); } } @@ -1600,6 +1564,14 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, // TF_Graph functions --------------------------------------------------------- +TF_Graph::TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + num_sessions(0), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) {} + TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { @@ -2119,7 +2091,8 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, const int max_node_id_before = g->graph.num_node_ids(); tensorflow::Scope scope = - NewInternalScope(&g->graph, &status->status, &g->refiner); + NewInternalScope(&g->graph, &status->status, &g->refiner) + .NewSubScope("gradients"); if (dx != nullptr) { std::vector dx_arg; @@ -2326,6 +2299,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, int ninputs, const TF_Output* outputs, int noutputs, const TF_Operation* const* target_opers, int ntargets, const char** handle, TF_Status* status) { + *handle = nullptr; + if (!ExtendSessionGraphHelper(session, status)) { return; } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 15139a47acf4b5bcf7a6b6fd77de5834f3f9189c..46758408c44ea4170abd4282294b77b07c762389 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -117,6 +117,7 @@ typedef enum TF_DataType { TF_COMPLEX128 = 18, // Double-precision complex TF_HALF = 19, TF_RESOURCE = 20, + TF_VARIANT = 21, } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding @@ -1101,8 +1102,7 @@ TF_CAPI_EXPORT extern void TF_SessionRun( // needed. // // On failure, out_status contains a tensorflow::Status with an error -// message. 
-// NOTE: This is EXPERIMENTAL and subject to change. +// message. *handle is set to nullptr. TF_CAPI_EXPORT extern void TF_SessionPRunSetup( TF_Session*, // Input names @@ -1118,7 +1118,6 @@ TF_CAPI_EXPORT extern void TF_SessionPRunSetup( // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. -// NOTE: This is EXPERIMENTAL and subject to change. TF_CAPI_EXPORT extern void TF_SessionPRun( TF_Session*, const char* handle, // Input tensors diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index f17ac26ad9665d7ea8cc1ef566cad81bba712b62..d077ad264b198d7f4b29dbd58808b09c8239a28e 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_C_C_API_INTERNAL_H_ +#define TENSORFLOW_C_C_API_INTERNAL_H_ + #include "tensorflow/c/c_api.h" #include @@ -56,13 +59,8 @@ struct TF_Library { }; struct TF_Graph { - TF_Graph() - : graph(tensorflow::OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) {} + TF_Graph(); + tensorflow::mutex mu; tensorflow::Graph graph GUARDED_BY(mu); @@ -117,3 +115,18 @@ struct TF_ImportGraphDefOptions { struct TF_DeviceList { std::vector response; }; + +namespace tensorflow { + +class TensorCApi { + public: + static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } + static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, + TensorBuffer* buf) { + return Tensor(static_cast(type), shape, buf); + } +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 
04540bd793dab34c2f707e9e995defe7b4e15858..25b6cbd8e7ad4b92b8ecbafe87c040190ad18b58 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -16,9 +16,12 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include +#include #include #include #include + +#include "tensorflow/c/c_test_util.h" #include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" @@ -41,24 +44,13 @@ limitations under the License. #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/util/equal_graph_def.h" -using tensorflow::int32; -using tensorflow::string; -using tensorflow::GraphDef; -using tensorflow::NodeDef; -using tensorflow::Tensor; -using tensorflow::TensorShape; - namespace tensorflow { -bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status); -TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src); -} // namespace tensorflow +TF_Tensor* TF_TensorFromTensor(const Tensor& src); +Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { -typedef std::unique_ptr - unique_tensor_ptr; - -TEST(CAPI, Version) { EXPECT_NE("", string(TF_Version())); } +TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { TF_Status* s = TF_NewStatus(); @@ -70,7 +62,7 @@ TEST(CAPI, Status) { TF_DeleteStatus(s); } -static void Deallocator(void* data, size_t, void* arg) { +void Deallocator(void* data, size_t, void* arg) { tensorflow::cpu_allocator()->DeallocateRaw(data); *reinterpret_cast(arg) = true; } @@ -143,7 +135,7 @@ TEST(CAPI, LibraryLoadFunctions) { TF_DeleteLibraryHandle(lib); } -static void TestEncodeDecode(int line, const std::vector& data) { +void TestEncodeDecode(int line, const std::vector& data) { const tensorflow::int64 n = data.size(); for (const std::vector& dims : std::vector>{ @@ -153,19 +145,16 @@ static void TestEncodeDecode(int line, const std::vector& data) { for 
(tensorflow::int64 i = 0; i < src.NumElements(); ++i) { src.flat()(i) = data[i]; } - TF_Tensor* dst = TF_Tensor_EncodeStrings(src); + TF_Tensor* dst = TF_TensorFromTensor(src); // Convert back to a C++ Tensor and ensure we get expected output. - TF_Status* status = TF_NewStatus(); Tensor output; - ASSERT_TRUE(TF_Tensor_DecodeStrings(dst, &output, status)) << line; - ASSERT_EQ(TF_OK, TF_GetCode(status)) << line; + ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line; ASSERT_EQ(src.NumElements(), output.NumElements()) << line; for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { ASSERT_EQ(data[i], output.flat()(i)) << line; } - TF_DeleteStatus(status); TF_DeleteTensor(dst); } } @@ -275,194 +264,6 @@ TEST(CAPI, GetAllOpList) { TF_DeleteBuffer(buf); } -static void Int32Deallocator(void* data, size_t, void* arg) { - delete[] static_cast(data); -} - -// Create a tensor with values of type TF_INT8 provided by `values`. -static TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, - const char* values) { - int64_t num_values = 1; - for (int i = 0; i < num_dims; ++i) { - num_values *= dims[i]; - } - TF_Tensor* t = - TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); - memcpy(TF_TensorData(t), values, sizeof(char) * num_values); - return t; -} - -static TF_Tensor* Int32Tensor(int32 v) { - const int num_bytes = sizeof(int32); - int32* values = new int32[1]; - values[0] = v; - return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes, - &Int32Deallocator, nullptr); -} - -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, - const char* name = "feed") { - TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); - TF_SetAttrType(desc, "dtype", TF_INT32); - return TF_FinishOperation(desc, s); -} - -TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, - const char* name = "const") { - TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); - TF_SetAttrTensor(desc, "value", t, s); - if 
(TF_GetCode(s) != TF_OK) return nullptr; - TF_SetAttrType(desc, "dtype", TF_TensorType(t)); - return TF_FinishOperation(desc, s); -} - -TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, - const char* name = "scalar") { - unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); - return Const(tensor.get(), graph, s, name); -} - -TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name = "add") { - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); - TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; - TF_AddInputList(desc, add_inputs, 2); - return TF_FinishOperation(desc, s); -} - -TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, - const char* name = "add") { - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); - TF_Output inputs[2] = {l, r}; - TF_AddInputList(desc, inputs, 2); - return TF_FinishOperation(desc, s); -} - -TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { - TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); - TF_Output neg_input = {n, 0}; - TF_AddInput(desc, neg_input); - return TF_FinishOperation(desc, s); -} - -TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, - TF_Status* s) { - TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than"); - TF_AddInput(desc, l); - TF_AddInput(desc, r); - return TF_FinishOperation(desc, s); -} - -bool IsPlaceholder(const NodeDef& node_def) { - if (node_def.op() != "Placeholder" || node_def.name() != "feed") { - return false; - } - bool found_dtype = false; - bool found_shape = false; - for (const auto& attr : node_def.attr()) { - if (attr.first == "dtype") { - if (attr.second.type() == tensorflow::DT_INT32) { - found_dtype = true; - } else { - return false; - } - } else if (attr.first == "shape") { - found_shape = true; - } - } - return found_dtype && found_shape; -} - -bool IsScalarConst(const NodeDef& node_def, int v) { - if 
(node_def.op() != "Const" || node_def.name() != "scalar") { - return false; - } - bool found_dtype = false; - bool found_value = false; - for (const auto& attr : node_def.attr()) { - if (attr.first == "dtype") { - if (attr.second.type() == tensorflow::DT_INT32) { - found_dtype = true; - } else { - return false; - } - } else if (attr.first == "value") { - if (attr.second.has_tensor() && - attr.second.tensor().int_val_size() == 1 && - attr.second.tensor().int_val(0) == v) { - found_value = true; - } else { - return false; - } - } - } - return found_dtype && found_value; -} - -bool IsAddN(const NodeDef& node_def, int n) { - if (node_def.op() != "AddN" || node_def.name() != "add" || - node_def.input_size() != n) { - return false; - } - bool found_t = false; - bool found_n = false; - for (const auto& attr : node_def.attr()) { - if (attr.first == "T") { - if (attr.second.type() == tensorflow::DT_INT32) { - found_t = true; - } else { - return false; - } - } else if (attr.first == "N") { - if (attr.second.i() == n) { - found_n = true; - } else { - return false; - } - } - } - return found_t && found_n; -} - -bool IsNeg(const NodeDef& node_def, const string& input) { - return node_def.op() == "Neg" && node_def.name() == "neg" && - node_def.input_size() == 1 && node_def.input(0) == input; -} - -bool GetGraphDef(TF_Graph* graph, GraphDef* graph_def) { - TF_Status* s = TF_NewStatus(); - TF_Buffer* buffer = TF_NewBuffer(); - TF_GraphToGraphDef(graph, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); - TF_DeleteBuffer(buffer); - TF_DeleteStatus(s); - return ret; -} - -bool GetNodeDef(TF_Operation* oper, NodeDef* node_def) { - TF_Status* s = TF_NewStatus(); - TF_Buffer* buffer = TF_NewBuffer(); - TF_OperationToNodeDef(oper, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - if (ret) ret = 
node_def->ParseFromArray(buffer->data, buffer->length); - TF_DeleteBuffer(buffer); - TF_DeleteStatus(s); - return ret; -} - -bool GetAttrValue(TF_Operation* oper, const char* attr_name, - tensorflow::AttrValue* attr_value, TF_Status* s) { - TF_Buffer* buffer = TF_NewBuffer(); - TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length); - TF_DeleteBuffer(buffer); - return ret; -} - TEST(CAPI, SetShape) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -879,114 +680,6 @@ TEST(CAPI, ImportGraphDef) { TF_DeleteStatus(s); } -class CSession { - public: - CSession(TF_Graph* graph, TF_Status* s) { - TF_SessionOptions* opts = TF_NewSessionOptions(); - session_ = TF_NewSession(graph, opts, s); - TF_DeleteSessionOptions(opts); - } - - explicit CSession(TF_Session* session) : session_(session) {} - - ~CSession() { - TF_Status* s = TF_NewStatus(); - CloseAndDelete(s); - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_DeleteStatus(s); - } - - void SetInputs(std::vector> inputs) { - DeleteInputValues(); - inputs_.clear(); - for (const auto& p : inputs) { - inputs_.emplace_back(TF_Output{p.first, 0}); - input_values_.emplace_back(p.second); - } - } - - void SetOutputs(std::initializer_list outputs) { - ResetOutputValues(); - outputs_.clear(); - for (TF_Operation* o : outputs) { - outputs_.emplace_back(TF_Output{o, 0}); - } - } - - void SetOutputs(const std::vector& outputs) { - ResetOutputValues(); - outputs_ = outputs; - } - - void SetTargets(std::initializer_list targets) { - targets_.clear(); - for (TF_Operation* t : targets) { - targets_.emplace_back(t); - } - } - - void Run(TF_Status* s) { - if (inputs_.size() != input_values_.size()) { - ADD_FAILURE() << "Call SetInputs() before Run()"; - return; - } - ResetOutputValues(); - output_values_.resize(outputs_.size(), nullptr); - - const TF_Output* inputs_ptr = inputs_.empty() ? 
nullptr : &inputs_[0]; - TF_Tensor* const* input_values_ptr = - input_values_.empty() ? nullptr : &input_values_[0]; - - const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0]; - TF_Tensor** output_values_ptr = - output_values_.empty() ? nullptr : &output_values_[0]; - - TF_Operation* const* targets_ptr = - targets_.empty() ? nullptr : &targets_[0]; - - TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, - inputs_.size(), outputs_ptr, output_values_ptr, - outputs_.size(), targets_ptr, targets_.size(), nullptr, s); - - DeleteInputValues(); - } - - void CloseAndDelete(TF_Status* s) { - DeleteInputValues(); - ResetOutputValues(); - if (session_ != nullptr) { - TF_CloseSession(session_, s); - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_DeleteSession(session_, s); - session_ = nullptr; - } - } - - TF_Tensor* output_tensor(int i) { return output_values_[i]; } - - private: - void DeleteInputValues() { - for (int i = 0; i < input_values_.size(); ++i) { - TF_DeleteTensor(input_values_[i]); - } - input_values_.clear(); - } - - void ResetOutputValues() { - for (int i = 0; i < output_values_.size(); ++i) { - if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]); - } - output_values_.clear(); - } - - TF_Session* session_; - std::vector inputs_; - std::vector input_values_; - std::vector outputs_; - std::vector output_values_; - std::vector targets_; -}; - TEST(CAPI, Session) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -1221,7 +914,7 @@ TEST(CAPI, SavedModel) { TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); - csession.SetInputs({{input_op, TF_Tensor_EncodeStrings(input)}}); + csession.SetInputs({{input_op, TF_TensorFromTensor(input)}}); const tensorflow::string output_op_name = tensorflow::ParseTensorName(output_name).first.ToString(); @@ -1272,308 +965,6 @@ TEST(CAPI, SavedModelNullArgsAreValid) { TF_DeleteStatus(s); } -class 
CApiWhileLoopTest : public ::testing::Test { - protected: - CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {} - - ~CApiWhileLoopTest() override { - TF_DeleteGraph(graph_); - TF_DeleteStatus(s_); - } - - void Init(int ninputs) { - DCHECK(inputs_.empty()); - DCHECK_GT(ninputs, 0); - - for (int i = 0; i < ninputs; ++i) { - TF_Operation* placeholder = Placeholder( - graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str()); - DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - inputs_.push_back({placeholder, 0}); - } - - original_graph_description_ = GraphDebugString(); - - params_.reset(new TF_WhileParams( - TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_))); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - ASSERT_EQ(original_graph_description_, GraphDebugString()) - << "TF_NewWhile() altered graph"; - - params_->name = "test_loop"; - - // Initialize outputs_ so we can easily detect errors/bugs - outputs_.resize(ninputs, {nullptr, -1}); - } - - void ExpectOK() { - TF_FinishWhile(params_.get(), s_, &outputs_[0]); - EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - } - - void ExpectError(TF_Code expected_code, const string& expected_msg) { - TF_FinishWhile(params_.get(), s_, &outputs_[0]); - EXPECT_EQ(expected_code, TF_GetCode(s_)); - EXPECT_EQ(expected_msg, TF_Message(s_)); - // TODO(skyewm): this assert is currently broken. Fix or remove guarantee. 
- // ASSERT_EQ(original_graph_description_, GraphDebugString()) << - // "TF_FinishWhile() altered graph on error"; - } - - void Run(std::initializer_list input_values) { - DCHECK_EQ(inputs_.size(), input_values.size()); - std::vector> inputs(inputs_.size()); - int i = 0; - for (int v : input_values) { - inputs[i] = {inputs_[i].oper, Int32Tensor(v)}; - ++i; - } - csession_.reset(new CSession(graph_, s_)); - csession_->SetInputs(inputs); - csession_->SetOutputs(outputs_); - csession_->Run(s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - } - - void ExpectOutputValue(int idx, int expected_value) { - TF_Tensor* out = csession_->output_tensor(idx); - ASSERT_TRUE(out != nullptr); - EXPECT_EQ(TF_INT32, TF_TensorType(out)); - EXPECT_EQ(0, TF_NumDims(out)); - ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); - int32* data = static_cast(TF_TensorData(out)); - EXPECT_EQ(expected_value, *data); - } - - // Create a valid conditional graph. Useful for testing unrelated errors. - void CreateCondGraph() { - TF_Operation* one = ScalarConst(1, params_->cond_graph, s_); - TF_Operation* less_than = - LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_); - DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - params_->cond_output = {less_than, 0}; - } - - string GraphDebugString() const { - TF_Buffer* buf = TF_NewBuffer(); - TF_GraphToGraphDef(graph_, buf, s_); - DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - GraphDef def; - bool success = def.ParseFromArray(buf->data, buf->length); - DCHECK(success); - TF_DeleteBuffer(buf); - return def.DebugString(); - } - - TF_Status* s_; - TF_Graph* graph_; - std::vector inputs_; // The inputs to the while loop - std::vector outputs_; // The final outputs of the while loop - std::unique_ptr params_; - std::unique_ptr csession_; - - private: - // Used to verify that errors don't change graph_ - string original_graph_description_; -}; - -TEST_F(CApiWhileLoopTest, BasicLoop) { - Init(2); - - // Validate TF_WhileParams 
returned by TF_NewWhile() - EXPECT_TRUE(params_->body_graph != nullptr); - EXPECT_TRUE(params_->cond_graph != nullptr); - - EXPECT_EQ(params_->ninputs, 2); - - ASSERT_TRUE(params_->cond_inputs != nullptr); - ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr); - EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr); - - ASSERT_TRUE(params_->body_inputs != nullptr); - EXPECT_TRUE(params_->body_inputs[0].oper != nullptr); - EXPECT_TRUE(params_->body_inputs[1].oper != nullptr); - - ASSERT_TRUE(params_->body_outputs != nullptr); - - // Create loop: while (input1 < input2) input1 += input2 + 1 - TF_Operation* less_than = - LessThan(params_->cond_inputs[0], params_->cond_inputs[1], - params_->cond_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - params_->cond_output = {less_than, 0}; - - TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1], - params_->body_graph, s_, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Operation* one = ScalarConst(1, params_->body_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - params_->body_outputs[0] = {add2, 0}; - params_->body_outputs[1] = params_->body_inputs[1]; - - // Finalize while loop - ExpectOK(); - - // Validate while loop outputs returned by TF_FinishWhile() - EXPECT_TRUE(outputs_[0].oper != nullptr); - EXPECT_GE(outputs_[0].index, 0); - EXPECT_TRUE(outputs_[1].oper != nullptr); - EXPECT_GE(outputs_[1].index, 0); - - // Run the graph - Run({-9, 2}); - ExpectOutputValue(0, 3); - ExpectOutputValue(1, 2); -} - -TEST_F(CApiWhileLoopTest, NestedLoop) { - Init(2); - // Create nested loop: - // while (input1 < 6) { - // inner_input1 = input1 - // while (inner_input1 < 3) { - // input2 += 1 - // inner_input1 += 2 - // } - // input1 += input2 - // } - // - // Expected execution with initial values input1 = input2 = 0: - // - // outer 
inner inner_ - // step# step# input1 input2 input1 - // ------------------------------------ - // 0 0 0 0 0 - // 0 1 0 1 2 - // 0 2 0 2 4 - // 0 - 2 2 - - // 1 0 2 2 2 - // 1 1 2 3 4 - // 1 - 5 3 - - // 2 0 5 3 5 - // 2 - 8 3 - - - // Create outer cond graph - TF_Operation* six = ScalarConst(6, params_->cond_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Operation* less_than = - LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - params_->cond_output = {less_than, 0}; - - // Create outer body graph - // Init inner graph - TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]}; - TF_WhileParams inner_params = - TF_NewWhile(params_->body_graph, inner_inputs, 2, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - inner_params.name = "inner_loop"; - - // Create inner cond graph - TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Operation* inner_less_than = LessThan( - inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - inner_params.cond_output = {inner_less_than, 0}; - - // Create inner body graph - TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one"); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two"); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* input2_add = - Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - inner_params.body_outputs[1] = {input2_add, 0}; - - TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two, - inner_params.body_graph, s_, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - inner_params.body_outputs[0] = {inner_input1_add, 0}; - - // 
Finalize inner graph - TF_Output inner_outputs[2] = {{nullptr, -1}}; - TF_FinishWhile(&inner_params, s_, inner_outputs); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - - TF_Operation* input1_add = - Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_); - ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - params_->body_outputs[0] = {input1_add, 0}; - - params_->body_outputs[1] = inner_outputs[1]; - - // Finalize outer graph - ExpectOK(); - - // Check for a few expected nodes - const char* node_name = "test_loop/cond/scalar"; - EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); - node_name = "test_loop/body/add"; - EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); - node_name = "test_loop/body/inner_loop/body/one"; - EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); - node_name = "test_loop/body/inner_loop/cond/less_than"; - EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); - - // Run the graph - Run({0, 0}); - ExpectOutputValue(0, 8); - ExpectOutputValue(1, 3); -} - -TEST_F(CApiWhileLoopTest, BadCondOutput) { - Init(1); - params_->body_outputs[0] = params_->body_inputs[0]; - ExpectError(TF_INVALID_ARGUMENT, - "TF_WhileParams `cond_output` field isn't set"); -} - -TEST_F(CApiWhileLoopTest, BadBodyOutput) { - Init(1); - CreateCondGraph(); - ExpectError(TF_INVALID_ARGUMENT, - "TF_WhileParams `body_outputs[0]` field isn't set"); -} - -TEST_F(CApiWhileLoopTest, NullName) { - Init(1); - CreateCondGraph(); - params_->body_outputs[0] = params_->body_inputs[0]; - params_->name = nullptr; - ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null"); -} - -TEST_F(CApiWhileLoopTest, WrongGraph) { - Init(1); - CreateCondGraph(); - // Set body output to output from outer graph - params_->body_outputs[0] = inputs_[0]; - // TODO(skyewm): improve error message - ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); -} - 
-TEST_F(CApiWhileLoopTest, BadTypes) { - Init(1); - CreateCondGraph(); - // Op that has a float input + output - TF_OperationDescription* desc = TF_NewOperation( - params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op"); - TF_AddInput(desc, params_->body_inputs[0]); - TF_FinishOperation(desc, s_); - ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); - string msg(TF_Message(s_)); - EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while " - "building NodeDef 'float_op'"), - msg.npos); - TF_AbortWhile(params_.get()); -} - REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") @@ -1765,13 +1156,13 @@ class CApiGradientsTest : public ::testing::Test { const float const3_val[] = {1.0, 1.0, 1.0, 1.0}; const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs"); } else { - const3 = OnesLike(expected_graph_, s_, matmul, "OnesLike"); + const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike"); } - TF_Operation* matmul1 = - MatMul(expected_graph_, s_, const3, const1, "MatMul_1", false, true); - TF_Operation* matmul2 = - MatMul(expected_graph_, s_, const0, const3, "MatMul_2", true, false); + TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1, + "gradients/MatMul", false, true); + TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3, + "gradients/MatMul_1", true, false); expected_grad_outputs[0] = {matmul1, 0}; expected_grad_outputs[1] = {matmul2, 0}; } @@ -2241,6 +1632,39 @@ TEST_F(CApiAttributesTest, Tensor) { TF_DeleteTensor(value); } +TEST_F(CApiAttributesTest, StringTensor) { + // Create the string-Tensor "atttribute" value. 
+ char encoded[] = { + 0, 0, 0, 0, 0, 0, 0, 0, // array[uint64] offsets + 1, // varint encoded string length + 'A', + }; + auto deallocator = [](void* data, size_t len, void* arg) {}; + unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0], + sizeof(encoded), deallocator, nullptr), + TF_DeleteTensor); + + // Create a TF_Operation with the attribute t_in + auto desc = init("tensor"); + TF_SetAttrTensor(desc, "v", t_in.get(), s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + auto oper = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Fetch the attribute back. + EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1); + TF_Tensor* t_out = nullptr; + TF_OperationGetAttrTensor(oper, "v", &t_out, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(TF_STRING, TF_TensorType(t_out)); + EXPECT_EQ(0, TF_NumDims(t_out)); + ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out)); + EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out), + TF_TensorByteSize(t_out))); + TF_DeleteTensor(t_out); +} + TEST_F(CApiAttributesTest, TensorList) { const char tensor1[] = {5, 7}; const int64_t dims1[] = {1, 2}; @@ -2252,7 +1676,8 @@ TEST_F(CApiAttributesTest, TensorList) { auto desc = init("list(tensor)"); TF_Tensor* tmp[] = { - Int8Tensor(dims1, ndims1, tensor1), Int8Tensor(dims2, ndims2, tensor2), + Int8Tensor(dims1, ndims1, tensor1), + Int8Tensor(dims2, ndims2, tensor2), }; TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_); for (int i = 0; i < TF_ARRAYSIZE(tmp); ++i) { @@ -2304,12 +1729,14 @@ TEST_F(CApiAttributesTest, Errors) { TF_OperationGetAttrString(oper, "v", nullptr, 0, s_); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); } + #undef EXPECT_TF_META +} // namespace +} // namespace tensorflow + // TODO(josh11b): Test: // * TF_SetDevice(desc, "/job:worker"); // * control inputs / outputs // * targets // * TF_DeleteGraph() before TF_DeleteSession() - -} // 
namespace diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..21603c1a07caf9e9fdcd53561a94fdf7756ec84d --- /dev/null +++ b/tensorflow/c/c_test_util.cc @@ -0,0 +1,304 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_test_util.h" + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +static void Int32Deallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + +TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { + int64_t num_values = 1; + for (int i = 0; i < num_dims; ++i) { + num_values *= dims[i]; + } + TF_Tensor* t = + TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); + memcpy(TF_TensorData(t), values, sizeof(char) * num_values); + return t; +} + +TF_Tensor* Int32Tensor(int32_t v) { + const int num_bytes = sizeof(int32_t); + int32_t* values = new int32_t[1]; + values[0] = v; + return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes, + &Int32Deallocator, nullptr); +} + +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); + TF_SetAttrType(desc, "dtype", 
TF_INT32); + return TF_FinishOperation(desc, s); +} + +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); + TF_SetAttrTensor(desc, "value", t, s); + if (TF_GetCode(s) != TF_OK) return nullptr; + TF_SetAttrType(desc, "dtype", TF_TensorType(t)); + return TF_FinishOperation(desc, s); +} + +TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + +TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; + TF_AddInputList(desc, add_inputs, 2); + return TF_FinishOperation(desc, s); +} + +TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, + const char* name) { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output inputs[2] = {l, r}; + TF_AddInputList(desc, inputs, 2); + return TF_FinishOperation(desc, s); +} + +TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); + TF_Output neg_input = {n, 0}; + TF_AddInput(desc, neg_input); + return TF_FinishOperation(desc, s); +} + +TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, + TF_Status* s) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than"); + TF_AddInput(desc, l); + TF_AddInput(desc, r); + return TF_FinishOperation(desc, s); +} + +bool IsPlaceholder(const tensorflow::NodeDef& node_def) { + if (node_def.op() != "Placeholder" || node_def.name() != "feed") { + return false; + } + bool found_dtype = false; + bool found_shape = false; + for (const auto& attr : node_def.attr()) { + if (attr.first == "dtype") { + if (attr.second.type() == 
tensorflow::DT_INT32) { + found_dtype = true; + } else { + return false; + } + } else if (attr.first == "shape") { + found_shape = true; + } + } + return found_dtype && found_shape; +} + +bool IsScalarConst(const tensorflow::NodeDef& node_def, int v) { + if (node_def.op() != "Const" || node_def.name() != "scalar") { + return false; + } + bool found_dtype = false; + bool found_value = false; + for (const auto& attr : node_def.attr()) { + if (attr.first == "dtype") { + if (attr.second.type() == tensorflow::DT_INT32) { + found_dtype = true; + } else { + return false; + } + } else if (attr.first == "value") { + if (attr.second.has_tensor() && + attr.second.tensor().int_val_size() == 1 && + attr.second.tensor().int_val(0) == v) { + found_value = true; + } else { + return false; + } + } + } + return found_dtype && found_value; +} + +bool IsAddN(const tensorflow::NodeDef& node_def, int n) { + if (node_def.op() != "AddN" || node_def.name() != "add" || + node_def.input_size() != n) { + return false; + } + bool found_t = false; + bool found_n = false; + for (const auto& attr : node_def.attr()) { + if (attr.first == "T") { + if (attr.second.type() == tensorflow::DT_INT32) { + found_t = true; + } else { + return false; + } + } else if (attr.first == "N") { + if (attr.second.i() == n) { + found_n = true; + } else { + return false; + } + } + } + return found_t && found_n; +} + +bool IsNeg(const tensorflow::NodeDef& node_def, const string& input) { + return node_def.op() == "Neg" && node_def.name() == "neg" && + node_def.input_size() == 1 && node_def.input(0) == input; +} + +bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def) { + TF_Status* s = TF_NewStatus(); + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + TF_DeleteStatus(s); + return ret; +} + 
+bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) { + TF_Status* s = TF_NewStatus(); + TF_Buffer* buffer = TF_NewBuffer(); + TF_OperationToNodeDef(oper, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + TF_DeleteStatus(s); + return ret; +} + +bool GetAttrValue(TF_Operation* oper, const char* attr_name, + tensorflow::AttrValue* attr_value, TF_Status* s) { + TF_Buffer* buffer = TF_NewBuffer(); + TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + return ret; +} + +CSession::CSession(TF_Graph* graph, TF_Status* s) { + TF_SessionOptions* opts = TF_NewSessionOptions(); + session_ = TF_NewSession(graph, opts, s); + TF_DeleteSessionOptions(opts); +} + +CSession::CSession(TF_Session* session) : session_(session) {} + +CSession::~CSession() { + TF_Status* s = TF_NewStatus(); + CloseAndDelete(s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteStatus(s); +} + +void CSession::SetInputs( + std::vector> inputs) { + DeleteInputValues(); + inputs_.clear(); + for (const auto& p : inputs) { + inputs_.emplace_back(TF_Output{p.first, 0}); + input_values_.emplace_back(p.second); + } +} + +void CSession::SetOutputs(std::initializer_list outputs) { + ResetOutputValues(); + outputs_.clear(); + for (TF_Operation* o : outputs) { + outputs_.emplace_back(TF_Output{o, 0}); + } + output_values_.resize(outputs_.size()); +} + +void CSession::SetOutputs(const std::vector& outputs) { + ResetOutputValues(); + outputs_ = outputs; + output_values_.resize(outputs_.size()); +} + +void CSession::SetTargets(std::initializer_list targets) { + targets_.clear(); + for (TF_Operation* t : targets) { + targets_.emplace_back(t); + } +} + +void CSession::Run(TF_Status* s) { + if 
(inputs_.size() != input_values_.size()) { + ADD_FAILURE() << "Call SetInputs() before Run()"; + return; + } + ResetOutputValues(); + output_values_.resize(outputs_.size(), nullptr); + + const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0]; + TF_Tensor* const* input_values_ptr = + input_values_.empty() ? nullptr : &input_values_[0]; + + const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0]; + TF_Tensor** output_values_ptr = + output_values_.empty() ? nullptr : &output_values_[0]; + + TF_Operation* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0]; + + TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, inputs_.size(), + outputs_ptr, output_values_ptr, outputs_.size(), targets_ptr, + targets_.size(), nullptr, s); + + DeleteInputValues(); +} + +void CSession::CloseAndDelete(TF_Status* s) { + DeleteInputValues(); + ResetOutputValues(); + if (session_ != nullptr) { + TF_CloseSession(session_, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteSession(session_, s); + session_ = nullptr; + } +} + +void CSession::DeleteInputValues() { + for (size_t i = 0; i < input_values_.size(); ++i) { + TF_DeleteTensor(input_values_[i]); + } + input_values_.clear(); +} + +void CSession::ResetOutputValues() { + for (size_t i = 0; i < output_values_.size(); ++i) { + if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]); + } + output_values_.clear(); +} diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0c0ba667bd0c3014efc6f0bd48ad0e63ccf4ee6e --- /dev/null +++ b/tensorflow/c/c_test_util.h @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ + +#include "tensorflow/c/c_api.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/test.h" + +using ::tensorflow::string; + +typedef std::unique_ptr + unique_tensor_ptr; + +// Create a tensor with values of type TF_INT8 provided by `values`. 
+TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); + +TF_Tensor* Int32Tensor(int32_t v); + +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, + const char* name = "feed"); + +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name = "const"); + +TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + +TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "add"); + +TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, + const char* name = "add"); + +TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s); + +TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); + +bool IsPlaceholder(const tensorflow::NodeDef& node_def); + +bool IsScalarConst(const tensorflow::NodeDef& node_def, int v); + +bool IsAddN(const tensorflow::NodeDef& node_def, int n); + +bool IsNeg(const tensorflow::NodeDef& node_def, const string& input); + +bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); + +bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def); + +bool GetAttrValue(TF_Operation* oper, const char* attr_name, + tensorflow::AttrValue* attr_value, TF_Status* s); + +class CSession { + public: + CSession(TF_Graph* graph, TF_Status* s); + explicit CSession(TF_Session* session); + + ~CSession(); + + void SetInputs(std::vector> inputs); + void SetOutputs(std::initializer_list outputs); + void SetOutputs(const std::vector& outputs); + void SetTargets(std::initializer_list targets); + + void Run(TF_Status* s); + + void CloseAndDelete(TF_Status* s); + + TF_Tensor* output_tensor(int i) { return output_values_[i]; } + + private: + void DeleteInputValues(); + void ResetOutputValues(); + + TF_Session* session_; + std::vector inputs_; + std::vector input_values_; + std::vector outputs_; + std::vector output_values_; + std::vector targets_; +}; + +#endif // 
THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..adca6c762526a85f015560efb22d3de185e2ae6c --- /dev/null +++ b/tensorflow/c/python_api.cc @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/python_api.h" + +#include "tensorflow/c/c_api_internal.h" + +namespace tensorflow { + +void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { + // TODO(skyewm): make sure cycles are prevented + mutex_lock l(graph->mu); + graph->graph.AddControlEdge(&input->node, &op->node); +} + +void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { + mutex_lock l(graph->mu); + op->node.set_requested_device(device); +} + +} // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h new file mode 100644 index 0000000000000000000000000000000000000000..e1a55d7755a76c778bf6a8120a8cf81adb6941dc --- /dev/null +++ b/tensorflow/c/python_api.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ +#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ + +#include "tensorflow/c/c_api.h" + +// These functions can be removed without notice. They exist to facilitate some +// refactoring of graph construction code in the Python API. + +namespace tensorflow { + +void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); + +void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f7e36226903411dc46c958fe1a096277d653ba4 --- /dev/null +++ b/tensorflow/c/while_loop_test.cc @@ -0,0 +1,329 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +using tensorflow::GraphDef; + +namespace { + +class CApiWhileLoopTest : public ::testing::Test { + protected: + CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {} + + ~CApiWhileLoopTest() override { + TF_DeleteGraph(graph_); + TF_DeleteStatus(s_); + } + + void Init(int ninputs) { + DCHECK(inputs_.empty()); + DCHECK_GT(ninputs, 0); + + for (int i = 0; i < ninputs; ++i) { + TF_Operation* placeholder = Placeholder( + graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str()); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inputs_.push_back({placeholder, 0}); + } + + original_graph_description_ = GraphDebugString(); + + params_.reset(new TF_WhileParams( + TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_))); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_EQ(original_graph_description_, GraphDebugString()) + << "TF_NewWhile() altered graph"; + + params_->name = "test_loop"; + + // Initialize outputs_ so we can easily detect errors/bugs + outputs_.resize(ninputs, {nullptr, -1}); + } + + void ExpectOK() { + TF_FinishWhile(params_.get(), s_, &outputs_[0]); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void ExpectError(TF_Code expected_code, const string& expected_msg) { + TF_FinishWhile(params_.get(), s_, &outputs_[0]); + EXPECT_EQ(expected_code, TF_GetCode(s_)); + EXPECT_EQ(expected_msg, TF_Message(s_)); + // TODO(skyewm): this assert is currently broken. Fix or remove guarantee. 
+ // ASSERT_EQ(original_graph_description_, GraphDebugString()) << + // "TF_FinishWhile() altered graph on error"; + } + + void Run(std::initializer_list input_values) { + DCHECK_EQ(inputs_.size(), input_values.size()); + std::vector> inputs(inputs_.size()); + int i = 0; + for (int v : input_values) { + inputs[i] = {inputs_[i].oper, Int32Tensor(v)}; + ++i; + } + csession_.reset(new CSession(graph_, s_)); + csession_->SetInputs(inputs); + csession_->SetOutputs(outputs_); + csession_->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void ExpectOutputValue(int idx, int expected_value) { + TF_Tensor* out = csession_->output_tensor(idx); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); + ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out)); + int32_t* data = static_cast(TF_TensorData(out)); + EXPECT_EQ(expected_value, *data); + } + + // Create a valid conditional graph. Useful for testing unrelated errors. + void CreateCondGraph() { + TF_Operation* one = ScalarConst(1, params_->cond_graph, s_); + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + } + + string GraphDebugString() const { + TF_Buffer* buf = TF_NewBuffer(); + TF_GraphToGraphDef(graph_, buf, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + GraphDef def; + bool success = def.ParseFromArray(buf->data, buf->length); + DCHECK(success); + TF_DeleteBuffer(buf); + return def.DebugString(); + } + + TF_Status* s_; + TF_Graph* graph_; + std::vector inputs_; // The inputs to the while loop + std::vector outputs_; // The final outputs of the while loop + std::unique_ptr params_; + std::unique_ptr csession_; + + private: + // Used to verify that errors don't change graph_ + string original_graph_description_; +}; + +TEST_F(CApiWhileLoopTest, BasicLoop) { + Init(2); + + // Validate 
TF_WhileParams returned by TF_NewWhile() + EXPECT_TRUE(params_->body_graph != nullptr); + EXPECT_TRUE(params_->cond_graph != nullptr); + + EXPECT_EQ(params_->ninputs, 2); + + ASSERT_TRUE(params_->cond_inputs != nullptr); + ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr); + EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr); + + ASSERT_TRUE(params_->body_inputs != nullptr); + EXPECT_TRUE(params_->body_inputs[0].oper != nullptr); + EXPECT_TRUE(params_->body_inputs[1].oper != nullptr); + + ASSERT_TRUE(params_->body_outputs != nullptr); + + // Create loop: while (input1 < input2) input1 += input2 + 1 + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], params_->cond_inputs[1], + params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1], + params_->body_graph, s_, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* one = ScalarConst(1, params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {add2, 0}; + params_->body_outputs[1] = params_->body_inputs[1]; + + // Finalize while loop + ExpectOK(); + + // Validate while loop outputs returned by TF_FinishWhile() + EXPECT_TRUE(outputs_[0].oper != nullptr); + EXPECT_GE(outputs_[0].index, 0); + EXPECT_TRUE(outputs_[1].oper != nullptr); + EXPECT_GE(outputs_[1].index, 0); + + // Run the graph + Run({-9, 2}); + ExpectOutputValue(0, 3); + ExpectOutputValue(1, 2); +} + +TEST_F(CApiWhileLoopTest, NestedLoop) { + Init(2); + // Create nested loop: + // while (input1 < 6) { + // inner_input1 = input1 + // while (inner_input1 < 3) { + // input2 += 1 + // inner_input1 += 2 + // } + // input1 += input2 + // } + // + // Expected execution with initial values input1 = input2 = 0: + 
// + // outer inner inner_ + // step# step# input1 input2 input1 + // ------------------------------------ + // 0 0 0 0 0 + // 0 1 0 1 2 + // 0 2 0 2 4 + // 0 - 2 2 - + // 1 0 2 2 2 + // 1 1 2 3 4 + // 1 - 5 3 - + // 2 0 5 3 5 + // 2 - 8 3 - + + // Create outer cond graph + TF_Operation* six = ScalarConst(6, params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + // Create outer body graph + // Init inner graph + TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]}; + TF_WhileParams inner_params = + TF_NewWhile(params_->body_graph, inner_inputs, 2, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.name = "inner_loop"; + + // Create inner cond graph + TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* inner_less_than = LessThan( + inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.cond_output = {inner_less_than, 0}; + + // Create inner body graph + TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* input2_add = + Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.body_outputs[1] = {input2_add, 0}; + + TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two, + inner_params.body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.body_outputs[0] = {inner_input1_add, 
0}; + + // Finalize inner graph + TF_Output inner_outputs[2] = {{nullptr, -1}}; + TF_FinishWhile(&inner_params, s_, inner_outputs); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* input1_add = + Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {input1_add, 0}; + + params_->body_outputs[1] = inner_outputs[1]; + + // Finalize outer graph + ExpectOK(); + + // Check for a few expected nodes + const char* node_name = "test_loop/cond/scalar"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/add"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/inner_loop/body/one"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/inner_loop/cond/less_than"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + + // Run the graph + Run({0, 0}); + ExpectOutputValue(0, 8); + ExpectOutputValue(1, 3); +} + +TEST_F(CApiWhileLoopTest, BadCondOutput) { + Init(1); + params_->body_outputs[0] = params_->body_inputs[0]; + ExpectError(TF_INVALID_ARGUMENT, + "TF_WhileParams `cond_output` field isn't set"); +} + +TEST_F(CApiWhileLoopTest, BadBodyOutput) { + Init(1); + CreateCondGraph(); + ExpectError(TF_INVALID_ARGUMENT, + "TF_WhileParams `body_outputs[0]` field isn't set"); +} + +TEST_F(CApiWhileLoopTest, NullName) { + Init(1); + CreateCondGraph(); + params_->body_outputs[0] = params_->body_inputs[0]; + params_->name = nullptr; + ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null"); +} + +TEST_F(CApiWhileLoopTest, WrongGraph) { + Init(1); + CreateCondGraph(); + // Set body output to output from outer graph + params_->body_outputs[0] = inputs_[0]; + // TODO(skyewm): improve error message + ExpectError(TF_INVALID_ARGUMENT, + "Requested return node 'p0' not found in graph def"); +} + 
+TEST_F(CApiWhileLoopTest, BadTypes) { + Init(1); + CreateCondGraph(); + // Op that has a float input + output + TF_OperationDescription* desc = TF_NewOperation( + params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op"); + TF_AddInput(desc, params_->body_inputs[0]); + TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + string msg(TF_Message(s_)); + EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while " + "building NodeDef 'float_op'"), + msg.npos); + TF_AbortWhile(params_.get()); +} + +} // namespace diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index b86731b1183fd59875d470aa48ba12c269789f0f..c65170dfe85b847259f7c59d437f62aa32ce1178 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -45,6 +45,7 @@ tf_cc_test( "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -61,7 +62,6 @@ cc_library( ":gradients", ":ops", ":scope", - "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -248,6 +248,7 @@ cc_library( ":gradients", "//tensorflow/core:lib_proto_parsing", ], + alwayslink = 1, ) tf_cc_test( @@ -274,11 +275,8 @@ cc_library( deps = [ ":cc_ops", ":grad_op_registry", - ":ops", - ":scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", ], + alwayslink = 1, ) tf_cc_test( @@ -305,11 +303,8 @@ cc_library( ":cc_ops", ":cc_ops_internal", ":grad_op_registry", - ":ops", - ":scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", ], + alwayslink = 1, ) tf_cc_test( @@ -441,6 +436,7 @@ cc_library_with_android_deps( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", + "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", ], @@ -481,10 
+477,23 @@ cc_binary( name = "tutorials_example_trainer", srcs = ["tutorials/example_trainer.cc"], copts = tf_copts(), - linkopts = [ - "-lpthread", - "-lm", - ], + linkopts = select({ + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//tensorflow:darwin": [ + "-lm", + "-lpthread", + ], + "//tensorflow:ios": [ + "-lm", + "-lpthread", + ], + "//conditions:default": [ + "-lm", + "-lpthread", + "-lrt", + ], + }), deps = [ ":cc_ops", "//tensorflow/core:core_cpu", @@ -514,7 +523,6 @@ cc_library( deps = [ ":coordinator", "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -547,8 +555,6 @@ cc_library( srcs = ["training/coordinator.cc"], hdrs = ["training/coordinator.h"], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 71aa986f918de68822d457422f6c7a73d6253819..80dd272f6f9dd5eecf5d7002bdf1c7c98e4c3ba3 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -18,8 +18,12 @@ limitations under the License. 
#include #include "tensorflow/cc/framework/cc_op_gen.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 8c00a6f70497df2c70f266a747197e50c98375bb..66a943410e2757ea5a5c55351c1fc20d5a5e3154 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -65,7 +65,7 @@ class SymbolicGradientBuilder { // gradients for the node associated with `src`. Status BackpropAlongEdge(const Output& dst_grad, const Output& src); - // Adds a node to the graph (returned in`grad`) that sums the in-bound + // Adds a node to the graph (returned in `grad`) that sums the in-bound // gradients to `src` (if there are more than one). Status SumGradients(const Output& src, Output* grad); @@ -152,12 +152,12 @@ Status SymbolicGradientBuilder::Initialize() { grad_outputs_->resize(inputs_.size()); // Populate `output_nodes_` from node ids in `outputs_`. output_nodes_.reserve(outputs_.size()); - for (int i = 0; i < outputs_.size(); ++i) { + for (size_t i = 0; i < outputs_.size(); ++i) { output_nodes_.insert(outputs_[i].node()->id()); } // Populate `input_nodes_` from Outputs in `inputs_`. input_nodes_.reserve(inputs_.size()); - for (int i = 0; i < inputs_.size(); ++i) { + for (size_t i = 0; i < inputs_.size(); ++i) { input_nodes_.insert({inputs_[i], i}); } @@ -341,7 +341,7 @@ Status SymbolicGradientBuilder::AddGradients() { // gradient function to the src node/output to which it should be // backproped. 
Maybe grad functions can return a vector of Output pairs to // make this association explicit. - int dx_index = 0; + size_t dx_index = 0; for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; if (dx_index == dx.size()) { @@ -352,6 +352,23 @@ Status SymbolicGradientBuilder::AddGradients() { BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()})); } } + + // Check if any input nodes still have pending gradients and have not been + // processed yet. This happens if not all outputs of a node are in 'inputs_'. + std::unordered_map requested_grads; + for (const Output& nout : inputs_) { + if (pending_[nout.node()->id()] > 0) { + DCHECK_GT(nout.node()->num_outputs(), 1); + int idx = input_nodes_[nout]; + DCHECK(((*grad_outputs_)[idx].node() == nullptr)); + TF_RETURN_IF_ERROR(SumGradients(nout, &(*grad_outputs_)[idx])); + ++requested_grads[nout.node()]; + } + } + for (const auto& p : requested_grads) { + int num_requested_inputs = p.first->num_outputs() - pending_[p.first->id()]; + CHECK_EQ(num_requested_inputs, p.second); + } return Status::OK(); } diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 6a249825812b4d39b55f7170a35436b6ae88c020..24af7d567b267332610eba2c8c8c57681fa0559b 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -259,6 +260,42 @@ TEST_F(GradientsTest, StackUnstack_StopBackprop) { CompareTestAndExpectedGraphs(); } +TEST_F(GradientsTest, StackUnstack_SubsetOfUnstackOutputs) { + // Constructs an unstack with three outputs, and takes the gradient with + // respect to only two of the outputs. Tests that the output gradients are + // computed. + for (const bool expected : {false, true}) { + const Scope& scope = expected ? scope_expected_ : scope_test_; + // Construct forward graph. + auto c = Const(scope, 1, {3, 4, 2}); + auto unpack = Unstack(scope, c, 3); + auto x = Identity(scope, unpack.output[0]); + auto y = Identity(scope, unpack.output[1]); + auto z = Identity(scope, unpack.output[2]); + TF_ASSERT_OK(scope.status()); + + // Construct grad inputs. + auto dy = Const(scope, 4, {4, 2}); + auto dz = Const(scope, 5, {4, 2}); + + if (expected) { + // Construct backward graph. + auto g1 = Identity(scope, dy); + auto g2 = Identity(scope, dz); + } else { + // Call AddSymbolicGradients. + std::vector grad_outputs; + TF_ASSERT_OK(AddSymbolicGradients(scope, {y, z}, + {unpack.output[1], unpack.output[2]}, + {dy, dz}, &grad_outputs)); + ASSERT_EQ(grad_outputs.size(), 2); + EXPECT_TRUE(grad_outputs[0].node() != nullptr); + EXPECT_TRUE(grad_outputs[1].node() != nullptr); + } + } + CompareTestAndExpectedGraphs(); +} + TEST_F(GradientsTest, DependentGradOutputs) { // Tests that dependent gradients (in this case the gradients w.r.t to the // output and one input of MatMul) are computed properly. 
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 32c0822de69da7989ceaa4028539db928b6fcea3..1948dd4e46b932775fdb5cbbdad7b66338b0fcf4 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -136,7 +136,7 @@ Scope::Impl::Impl(const std::shared_ptr& graph, Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = - new ShapeRefiner(graph->versions().producer(), graph->op_registry()); + new ShapeRefiner(graph->versions(), graph->op_registry()); return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); } diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 37f07e71a0dff9144f193679bbcfcf581c1538cf..6545e4ee3eb406436937a43ddac66d017af8e108 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -100,6 +100,17 @@ Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + grad_outputs->push_back(Identity(scope, grad_inputs[0])); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad); + Status SplitGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -247,6 +258,18 @@ Status ScatterNdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad); +Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto indices = op.input(1); + grad_outputs->push_back(Identity(scope, 
grad_inputs[0])); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); + return scope.status(); +} +REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad); + +template Status PadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -259,9 +282,14 @@ Status PadGrad(const Scope& scope, const Operation& op, auto begin = Reshape(scope, pad_before, {-1}); grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x))); grad_outputs->push_back(NoGradient()); + // PadV2 adds a "constant_values" input. + if (IsPadV2) { + grad_outputs->push_back(NoGradient()); + } return scope.status(); } -REGISTER_GRADIENT_OP("Pad", PadGrad); +REGISTER_GRADIENT_OP("Pad", PadGrad); +REGISTER_GRADIENT_OP("PadV2", PadGrad); Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 5798b5b509fc14e6c9d95d4fd42aca893254f775..1777e181451b267f52a418888912ed1393bdf8b1 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -233,6 +233,28 @@ TEST_F(ArrayGradTest, ScatterNdGrad_SliceIndexing) { RunTest(updates, updates_shape, y, y_shape); } +TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SimpleIndexing) { + TensorShape updates_shape({4}); + TensorShape input_shape({8}); + auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape)); + auto updates = + Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); + auto indices = Const(scope_, {{4}, {3}, {1}, {7}}); + auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates); + RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); +} + +TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SliceIndexing) { + TensorShape updates_shape({2, 4, 4}); + TensorShape 
input_shape({4, 4, 4}); + auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape)); + auto updates = + Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); + auto indices = Const(scope_, {{0}, {2}}); + auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates); + RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); +} + TEST_F(ArrayGradTest, PadGrad) { TensorShape x_shape({2, 3}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 71d9a8ed7be5ea75a3b26224df871b955f05c132..0b9b665b1eb4420827b152a88d9023ceab4d932d 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -203,6 +203,46 @@ Status TanhGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Tanh", TanhGrad); +Status AsinhGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = asinh(x) + // dy/dx = 1 / cosh(y) + auto dydx = Reciprocal(scope, Cosh(scope, op.output(0))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Asinh", AsinhGrad); + +Status AcoshGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = acosh(x) + // dy/dx = 1 / sinh(y) + auto dydx = Reciprocal(scope, Sinh(scope, op.output(0))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Acosh", AcoshGrad); + +Status AtanhGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = atanh(x) + // dy/dx = 1 / (1 - x^2) + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto 
dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0)))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Atanh", AtanhGrad); + Status SigmoidGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1653b04378f30bd788d549da04d4140ac7d6317e..48b3ddbe90c2313ec0aa50729f277a1c258de52c 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -48,6 +48,9 @@ class CWiseUnaryGradTest : public ::testing::Test { SINH, COSH, TANH, + ASINH, + ACOSH, + ATANH, SIGMOID, SIGN, SIN, @@ -122,6 +125,15 @@ class CWiseUnaryGradTest : public ::testing::Test { case TANH: y = Tanh(scope_, x); break; + case ASINH: + y = Asinh(scope_, x); + break; + case ACOSH: + y = Acosh(scope_, x); + break; + case ATANH: + y = Atanh(scope_, x); + break; case SIGMOID: y = Sigmoid(scope_, x); break; @@ -413,6 +425,76 @@ TEST_F(CWiseUnaryGradTest, Tanh_Complex) { TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); } +TEST_F(CWiseUnaryGradTest, Asinh) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + auto y = std::asinh(x); + return dy / std::cosh(y); + }; + TestCWiseGrad(ASINH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Asinh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + auto y = std::asinh(x); + return dy / conjugate(std::cosh(y)); + }; + TestCWiseGrad(ASINH, x_fn, dy_fn, dx_fn); +} + 
+TEST_F(CWiseUnaryGradTest, Acosh) { + auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); }; + auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13, 14}); }; + auto dx_fn = [this](const float x, const float dy) { + auto y = std::acosh(x); + return dy / std::sinh(y); + }; + TestCWiseGrad(ACOSH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Acosh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{2, 2}, {3, 3}, {1, 4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + auto y = std::acosh(x); + return dy / conjugate(std::sinh(y)); + }; + TestCWiseGrad(ACOSH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atanh) { + auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * (1. / (1. 
- x * x)); + }; + TestCWiseGrad(ATANH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atanh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / conjugate(one_ - x * x); + }; + TestCWiseGrad(ATANH, x_fn, dy_fn, dx_fn); +} + TEST_F(CWiseUnaryGradTest, Sigmoid) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 5e5203d09055d65cb1dcc16e091f6e5028ee7ae1..f9d69ff8967e7c7d56f5771a8ccbd4091f7bc8c0 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -46,6 +46,19 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); +Status LogSoftmaxGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + + auto softmax = Exp(scope, op.output(0)); + auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true)); + auto mul = Mul(scope, sum, softmax); + auto dx = Sub(scope, grad_inputs[0], mul); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LogSoftmax", LogSoftmaxGrad); + Status ReluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -73,6 +86,15 @@ Status EluGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Elu", EluGradHelper); +Status SeluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0)); + grad_outputs->push_back(dx); + return 
scope.status(); +} +REGISTER_GRADIENT_OP("Selu", SeluGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 70c9bd4e08b2b46866a44becc8fe1305fec48ea9..eab5b446261cc7c69a4aa3b26a2debd402c9bdd9 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -57,6 +57,19 @@ TEST_F(NNGradTest, SoftmaxGrad) { RunTest(x, shape, y, shape); } +TEST_F(NNGradTest, LogSoftmaxGrad) { + TensorShape shape({5, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = LogSoftmax(scope_, x); + // Avoid numerical instability when computing finite differences. + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, + 0.1f, 0.3f, 0.5f, 0.7f, 0.8f, + -0.1f, 0.1f, 0.1f, 0.1f, 1.2f}, + {5, 3}); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, ReluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); @@ -90,5 +103,15 @@ TEST_F(NNGradTest, EluGrad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, SeluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Selu(scope_, x); + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt index 1dffb10c03379571907e921c1add98d1f11625c3..2252cbb2892af9b0d9938a7864235d3d6b4ec005 100644 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -100,6 +100,10 @@ op { name: "Stack" skip: true } op { name: "StackClose" skip: true } op { name: "StackPop" skip: true } op { name: "StackPush" skip: true } +op { name: "StackV2" skip: true } 
+op { name: "StackCloseV2" skip: true } +op { name: "StackPopV2" skip: true } +op { name: "StackPushV2" skip: true } op { name: "TensorArrayCloseV2" skip: true } op { name: "TensorArrayCloseV3" rename_to: "TensorArrayClose" } @@ -173,6 +177,7 @@ op { name: "MaxPoolGradWithArgmax" hide: true } op { name: "ReluGrad" hide: true } op { name: "Relu6Grad" hide: true } op { name: "EluGrad" hide: true } +op { name: "SeluGrad" hide: true } op { name: "SoftplusGrad" hide: true } op { name: "SoftsignGrad" hide: true } op { name: "FractionalAvgPoolGrad" hide: true } diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 94a3b3cf465a279e3bb44344739499ad670119c3..c940df8a8761d97a859be3af30980ff79ca3577a 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -21,6 +21,9 @@ namespace tensorflow { /// SavedModel assets directory. constexpr char kSavedModelAssetsDirectory[] = "assets"; +/// SavedModel assets.extra directory. +constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra"; + /// SavedModel assets key for graph collection-def. constexpr char kSavedModelAssetsKey[] = "saved_model_assets"; diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 807f5904afcf36890f4bd02f0d811a3ebe0cceba..f98abc8a817eca7bc129bb03a2ad31b97d957065 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/saved_model.pb.h" @@ -76,8 +77,16 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, return Status::OK(); } } + string tags_as_string = "{ "; + for (const string& tag : tags) { + tags_as_string = strings::StrCat(tags_as_string, tag, " "); + } + tags_as_string = strings::StrCat(tags_as_string, "}"); return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags."); + "Could not find meta graph def matching supplied tags: " + + tags_as_string + + ". To inspect available tag-sets in the SavedModel, please " + "use the SavedModel CLI: `saved_model_cli`"); } Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index cef29e7b071e538a60193fd998acc0fb29c2cea3..0ad6b33bba5fcceaca68e2f179cef2232c689a80 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -133,9 +133,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags.")) + EXPECT_TRUE(StringPiece(st.error_message()) + .contains("Could not find meta graph def matching supplied " + "tags: { missing-tag }")) << st.error_message(); } @@ -151,7 +151,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { EXPECT_FALSE(st.ok()); EXPECT_TRUE( StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags.")) + .contains("Could not find meta graph def matching 
supplied tags: ")) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index 48ab1158e462af25c27a728e404a041516e82057..2b0b2d5c7fb33768494c1781669c1adcb875a579 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -18,10 +18,13 @@ limitations under the License. namespace tensorflow { +/// Tag for the `gpu` graph. +constexpr char kSavedModelTagGpu[] = "gpu"; + /// Tag for the `serving` graph. constexpr char kSavedModelTagServe[] = "serve"; -/// Tag for the `training` graph.` +/// Tag for the `training` graph. constexpr char kSavedModelTagTrain[] = "train"; } // namespace tensorflow diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 1f6fe28188cfbb6a64935e4a3f70cf8e0f6eb9ad..f956602ba221bbbb3c2fc9c7df7d452da833c002 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -108,6 +108,7 @@ cc_test( deps = [ ":tfcompile_lib", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -126,16 +127,7 @@ cc_library( deps = [ ":tfcompile_lib", ":tfcompile_proto", - "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", - "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", - "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", - "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", - "//tensorflow/compiler/xla/legacy_flags:service_flags", - "//tensorflow/compiler/xla/legacy_flags:util_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -161,6 +153,40 @@ tf_library( tags = 
["manual"], ) +# A test of tf_library that includes a graph with an unknown op, but where +# the compilation works because the unknown op is not needed for the fetches. +tf_library( + name = "test_graph_tfunknownop", + testonly = 1, + config = "test_graph_tfunknownop.config.pbtxt", + cpp_class = "UnknownOpAddComp", + graph = "test_graph_tfunknownop.pbtxt", + tags = ["manual"], +) + +# A test of tf_library that includes a graph with an unknown op, but where +# the compilation works because the op between the unknown op and the +# fetches is a feed. +tf_library( + name = "test_graph_tfunknownop2", + testonly = 1, + config = "test_graph_tfunknownop2.config.pbtxt", + cpp_class = "UnknownOpAddComp", + graph = "test_graph_tfunknownop.pbtxt", + tags = ["manual"], +) + +# A test of tf_library that includes a graph with an unknown op, but where +# the compilation works because the unknown op is fed. +tf_library( + name = "test_graph_tfunknownop3", + testonly = 1, + config = "test_graph_tfunknownop3.config.pbtxt", + cpp_class = "UnknownOpAddComp", + graph = "test_graph_tfunknownop.pbtxt", + tags = ["manual"], +) + # Utility library for benchmark binaries, used by the *_benchmark rules that are # added by the tfcompile bazel macro. cc_library( @@ -204,6 +230,7 @@ test_suite( tests = [ ":benchmark_test", ":test_graph_tfadd_test", + ":test_graph_tfunknownop_test", "//tensorflow/compiler/aot/tests:all_tests", ], ) diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index ca17c5ab690f606bd531638fece8b0a74cdd8c18..03bdd63623dcd1176c4598107281db9ad72e1947 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -77,66 +79,51 @@ Status DumpGraph(const MainFlags& flags, const string& name, return WriteTextProto(Env::Default(), file, graph_def); } -string TensorIdToString(const TensorId& id) { - return strings::StrCat(id.node_name(), ":", id.output_index()); -} - typedef std::unordered_map NodeMap; // Each feed id identifies the positional output of some node, which may consist -// of multiple edges. For each feed node, replaces all matching edges so that -// they point from a new _Arg node instead. +// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed +// tensor with a placeholder. For each feed tensor, replaces all edges so they +// point from a new _Arg node instead. 
Status AddArgNodes(Graph* graph, const NodeMap& node_map, - const protobuf::RepeatedPtrField& feeds) { + const protobuf::RepeatedPtrField& feeds, + const std::unordered_map& feed_remapping) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const Feed& feed = feeds[arg_index]; - const TensorId& id = feed.id(); - auto it = node_map.find(id.node_name()); - if (it == node_map.end()) { - return errors::NotFound("Can't find feed id: ", TensorIdToString(id)); - } - const Node* feed_node = it->second; - if (id.output_index() >= feed_node->num_outputs()) { - return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id), - ", output index should be < ", - feed_node->num_outputs()); - } - // TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr - // if we can determine it. That way the graph will be initialized with - // whatever shapes we can infer, while the user can still explicitly specify - // or override them. + // All feeds have been replaced by placeholders. + const int output_index = 0; + + const auto remap_it = feed_remapping.find(TensorIdToString(feed.id())); + auto node_it = node_map.find(remap_it->second); + const Node* feed_node = node_it->second; + + // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a + // "_shape" attr if we can determine it. That way the graph will be + // initialized with whatever shapes we can infer, while the user can still + // explicitly specify or override them. 
Node* arg_node = nullptr; TF_RETURN_IF_ERROR( NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) - .Attr("T", BaseType(feed_node->output_type(id.output_index()))) + .Attr("T", BaseType(feed_node->output_type(output_index))) .Attr("index", arg_index) - .Attr(kFeedIdAttr, TensorIdToString(id)) + .Attr(kFeedIdAttr, TensorIdToString(feed.id())) .Attr(kShapeAttr, TensorShape(feed.shape())) .Attr(kDebugNameAttr, feed.name()) .Finalize(graph, &arg_node)); + // Collects out-edges from the feed node that have a matching edge index; - // these will be replaced with edges from the arg node instead. Also - // replaces all control edges from Placeholder feed nodes; similar code - // exists in subgraph::RewriteGraphForExecution. - // TODO(toddw): Why only replace control edges from Placeholder? + // these will be replaced with edges from the arg node instead. // // We must collect the edges first and process them in a second pass, since // removing the edge from the graph invalidates feed_node->out_edges. std::vector feed_edges; for (const Edge* edge : feed_node->out_edges()) { - if (edge->src_output() == id.output_index() || - (edge->src_output() == Graph::kControlSlot && - feed_node->type_string() == "Placeholder")) { + if (edge->src_output() == output_index) { feed_edges.push_back(edge); } } for (const Edge* edge : feed_edges) { - if (edge->src_output() == id.output_index()) { - graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); - } else { - CHECK_EQ(edge->src_output(), Graph::kControlSlot); - graph->AddControlEdge(arg_node, edge->dst()); - } + graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); } } @@ -178,13 +165,16 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // fetch ids respectively), and rewrites the edges so that inputs flow from _Arg // nodes, and outputs flow to _Retval nodes. This allows the symbolic graph // execution to know the input and output args for the generated function. 
-Status RewriteAndPruneGraph(Graph* graph, const Config& config, - const MainFlags& flags) { +Status RewriteAndPruneGraph( + Graph* graph, const Config& config, + const std::unordered_map& feed_remapping, + const MainFlags& flags) { NodeMap node_map; for (Node* n : graph->nodes()) { node_map[n->name()] = n; } - TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed())); + TF_RETURN_IF_ERROR( + AddArgNodes(graph, node_map, config.feed(), feed_remapping)); std::unordered_set retval_nodes; TF_RETURN_IF_ERROR( AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); @@ -265,7 +255,9 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TensorShape shape; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } @@ -378,14 +370,32 @@ Status CompileXla(xla::CompileOnlyClient* client, Status InitGraph(const GraphDef& graph_def, const Config& config, const MainFlags& flags, std::unique_ptr* graph) { TF_RETURN_IF_ERROR(ValidateConfig(config)); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); std::unique_ptr g(new Graph(flib_def)); - GraphDef copy_def(graph_def); - TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(), - 0 /*node_offset*/)); + + // Replace references to fed tensors with references to newly added + // placeholders. + GraphDef first_copy_def = graph_def; + + // Maps from name:port of a feed to the name:port of the placeholder to use.
+ std::unordered_map feed_remapping; + TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(), + &feed_remapping, &first_copy_def)); + + // Prune the GraphDef first so that unknown ops that we aren't compiling get + // filtered out. + GraphDef second_copy_def; + TF_RETURN_IF_ERROR( + PruneGraphDefInto(config, first_copy_def, &second_copy_def)); + + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( + &second_copy_def, *g->op_registry(), 0 /*node_offset*/)); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), + second_copy_def, g.get())); TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get())); - TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags)); + RewriteAndPruneGraph(g.get(), config, feed_remapping, flags)); *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt index 5625c0ab03893c997245a6449d145b9149b48627..f2d9c34b2d1d68aa80245a6f3379b3759bb9f4b9 100644 --- a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt @@ -6,7 +6,7 @@ feed { } } feed { - id { node_name: "y_const" } + id { node_name: "y_reshape" } shape { dim { size: 1 } } diff --git a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt index 91c900e06d7547fe9a377a427b6ca56b9e46942d..665c9fe28721b25c544c30ecd1b4dfc399934314 100644 --- a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt @@ -4,15 +4,7 @@ node { attr { key: "value" value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } + tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } int_val: 1 } } } attr { @@ -28,15 +20,7 @@ node { attr { key: "value" value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } + tensor { 
dtype: DT_INT32 tensor_shape { dim { size: 1 } } int_val: 2 } } } attr { @@ -46,11 +30,20 @@ node { } } } +node { + name : "y_reshape" + op : "Reshape" + input : "y_const" + input : "y_shape" + attr { key: "T" value { type: DT_INT32 } } + # Attribute TShape not specified; needs to be set to its default + # by tfcompile. +} node { name : "x_y_sum" op : "Add" input : "x_const" - input : "y_const" + input : "y_reshape" attr { key : "T" value { diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5625c0ab03893c997245a6449d145b9149b48627 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..48b881bb9462dc30944a1377d4d2a2c58b9dfe43 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt @@ -0,0 +1,58 @@ +node { + name : "x_const" + op : "Const" + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { dim { size: 1 } } + int_val: 1 + } + } + } + attr { key : "dtype" value { type: DT_INT32 } } +} +node { + name : "y_const" + op : "Const" + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { dim { size: 1 } } + int_val: 2 + } + } + } + attr { key: "dtype" value { type: DT_INT32 } } +} +node { + name : "x_y_sum" + op : "Add" + input : "x_const" + input : "y_const" + attr { key : "T" value { type: DT_INT32 } } +} +node { + name : "z" + op : "SomeUnknownOp" + input : "x_const" +} +node { + name : 
"z_identity" + op : "Identity" + input : "z:1" + attr { key : "T" value { type: DT_INT32 } } +} +node { + name : "x_z_sum" + op : "Add" + input : "x_const" + input : "z_identity" + attr { key : "T" value { type: DT_INT32 } } +} +versions { + producer: 15 +} diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7370ed370d314052ed23d4ceca22cab7def65485 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt @@ -0,0 +1,25 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "z_identity"} + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} +fetch { + id { node_name: "x_z_sum" } +} diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b2d7d5457427775fe2f00e079ced6b23c3308230 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt @@ -0,0 +1,26 @@ +# Text form of tensorflow.tfcompile.Config proto. 
+feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "z" output_index: 1} + shape { + dim { size: 1 } + } + type: DT_INT32 +} +fetch { + id { node_name: "x_y_sum" } +} +fetch { + id { node_name: "x_z_sum" } +} diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 4be4e0fbb39c710e64478ee6b98a8dd1fc0441b9..12e1485b484d6cb9f3f896db567e9a6fae719943 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -27,6 +27,15 @@ def tf_library(name, graph, config, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. + Given an invocation of tf_library(name="foo", ...), generates the following + build targets: + foo: A cc_library containing the generated header and computation. + foo_test: A cc_test with simple tests and benchmarks. Only created if + gen_test=True. + foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful + for mobile devices or other platforms that can't compile the + full test libraries. Only created if gen_benchmark=True. + Args: name: The name of the build rule. graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/aot/tfcompile.proto index be3f5043501c71c844a00b5a5b23fa4285c00ec6..cd83840d894f2a28ca70c54f3320a6287b4a0a20 100644 --- a/tensorflow/compiler/aot/tfcompile.proto +++ b/tensorflow/compiler/aot/tfcompile.proto @@ -7,6 +7,7 @@ option java_multiple_files = true; option java_package = "org.tensorflow.tfcompile"; import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; // TensorId identifies a tensor in a TensorFlow graph, by specifying the output // index of a particular node in the graph. 
If the output of the named node @@ -23,6 +24,12 @@ message Feed { TensorId id = 1; TensorShapeProto shape = 2; string name = 3; // Optional name for generated code. + + // Optional data type. This is not normally required, as the graph itself + // contains this information. However, if the node being fed is an op that + // is not linked into the tfcompile binary, then the type cannot be inferred + // from the node; in this case, the type should be set here. + DataType type = 4; }; // Fetch represents a single fetch tensor in the graph, which corresponds to an diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 6fed46b4329606baeed21dd9ee4d34849a7c50a0..be2cfe4734e0493ba41a1bda23606a65d2cb4af4 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -23,16 +23,7 @@ limitations under the License. #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/tfcompile.pb.h" #include "tensorflow/compiler/aot/tfcompile_util.h" -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/util_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -133,19 +124,11 @@ int main(int argc, char** argv) { flags.target_triple = "x86_64-pc-linux"; 
flags.out_object = "out.o"; flags.out_header = "out.h"; + flags.entry_point = "entry"; std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list); - xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list); - xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); - xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); - xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendLlvmUtilFlags(&flag_list); - xla::legacy_flags::AppendServiceFlags(&flag_list); - xla::legacy_flags::AppendUtilFlags(&flag_list); tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc index fd073a2e2623b4b24ddc58360525886f3fc1b3ac..e6a4705b6c24eccac6528c64d030f9e37eb5c3f4 100644 --- a/tensorflow/compiler/aot/tfcompile_util.cc +++ b/tensorflow/compiler/aot/tfcompile_util.cc @@ -15,13 +15,19 @@ limitations under the License. 
#include "tensorflow/compiler/aot/tfcompile_util.h" +#include #include +#include #include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tfcompile { @@ -115,5 +121,164 @@ Status ValidateConfig(const Config& config) { return Status::OK(); } +Status AddPlaceholdersForFeeds( + const Config& config, const OpRegistryInterface* op_registry, + std::unordered_map* feed_remapping, GraphDef* graph_def) { + struct PlaceholderInfo { + const Feed* feed = nullptr; // point to Feed in . + string placeholder_name; + DataType data_type = DT_INVALID; + }; + + // Put each fed tensor into a map by name:port. A map is used for determinism + // when creating placeholders (genrules want deterministic output). + std::map placeholder_info; + for (int i = 0; i < config.feed_size(); ++i) { + const Feed* feed = &config.feed(i); + const string name_port = TensorIdToString(feed->id()); + auto& info = placeholder_info[name_port]; + info.feed = feed; + info.placeholder_name = strings::StrCat( + "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); + (*feed_remapping)[name_port] = info.placeholder_name; + } + + // Verify node exists and determine data type. + std::unordered_map name_to_node; + for (int i = 0; i < graph_def->node_size(); ++i) { + name_to_node[graph_def->node(i).name()] = &graph_def->node(i); + } + for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { + PlaceholderInfo& info = it->second; + const TensorId& feed_id = info.feed->id(); + + // Find the existing node and determine data type. 
+ auto node_it = name_to_node.find(feed_id.node_name()); + if (node_it == name_to_node.end()) { + return errors::NotFound("Can't find feed node: ", + TensorIdToString(feed_id)); + } + const NodeDef* existing = node_it->second; + + if (info.feed->type() != DT_INVALID) { + info.data_type = info.feed->type(); + } else { + // Build the node in order to infer its type. + + // Must first add default attrs as well, so do this in a copied GraphDef. + GraphDef gd; + *gd.mutable_versions() = graph_def->versions(); + *gd.add_node() = *existing; + TF_RETURN_IF_ERROR( + AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); + + // Now build the node from the copied node def. + Graph g(op_registry); + g.set_versions(graph_def->versions()); + Status status; + Node* feed_node = g.AddNode(gd.node(0), &status); + TF_RETURN_IF_ERROR(status); + info.data_type = + BaseType(feed_node->output_type(info.feed->id().output_index())); + } + } + + // Create placeholders. Note that we could avoid creating a placeholder for + // feeds which are already placeholders, but we omit that to avoid more cases + // in this code. + for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { + const PlaceholderInfo& info = it->second; + NodeDef* d = graph_def->add_node(); + d->set_name(info.placeholder_name); + d->set_op("PlaceholderV2"); + auto& attr_map = *d->mutable_attr(); + attr_map["dtype"].set_type(info.data_type); + *attr_map["shape"].mutable_shape() = info.feed->shape(); + } + + // Rewrite references to the fed tensors to refer to the placeholder. 
+ for (int i = 0; i < graph_def->node_size(); ++i) { + NodeDef* node_def = graph_def->mutable_node(i); + for (int j = 0; j < node_def->input_size(); ++j) { + auto id = ParseTensorName(node_def->input(j)); + auto it = placeholder_info.find(id.ToString()); + if (it != placeholder_info.end()) { + node_def->set_input(j, it->second.placeholder_name); + } + } + } + + return Status::OK(); +} + +Status PruneGraphDefInto(const Config& config, const GraphDef& in, + GraphDef* out) { + *out = in; + out->clear_node(); + + // Tensors needed for feeding. + std::set> feed_tensors; + for (const auto& feed_config : config.feed()) { + feed_tensors.insert(std::make_pair(feed_config.id().node_name(), + feed_config.id().output_index())); + } + + // Maps node name to reachability. + std::unordered_map> node_by_name; + for (const NodeDef& node : in.node()) { + node_by_name[node.name()] = std::pair(false, &node); + } + + // Traverse. + std::queue name_queue; + for (int i = 0; i < config.fetch_size(); ++i) { + name_queue.push(config.fetch(i).id().node_name()); + } + while (!name_queue.empty()) { + const string name = name_queue.front(); + name_queue.pop(); + + auto find_it = node_by_name.find(name); + if (find_it == node_by_name.end()) { + return errors::InvalidArgument("While pruning graph, node ", name, + " needed but not found in the graph."); + } + auto& map_entry = find_it->second; + if (map_entry.first) { + continue; + } + map_entry.first = true; + + // Push input nodes of the currently visited node to name_queue. + for (const string& in_edge : map_entry.second->input()) { + auto id = ParseTensorName(in_edge); + const string node_name = id.first.ToString(); + if (feed_tensors.find(std::make_pair(node_name, id.second)) == + feed_tensors.end()) { + name_queue.push(node_name); + } else { + // The input tensor is from an edge that is being fed. 
Therefore, + // we skip recursing down that edge, to avoid requiring nodes that + // may not be needed (note that the input node may still be added + // to name_queue later if one of its output edges is not being fed). + } + } + } + + // Copy over, preserving order of original and only nodes that are reachable + // from the fetches. + out->mutable_node()->Reserve(in.node_size()); + for (const NodeDef& node : in.node()) { + if (node_by_name[node.name()].first) { + *out->add_node() = node; + } + } + return Status::OK(); +} + +string TensorIdToString(const TensorId& id) { + return strings::StrCat(id.node_name(), ":", id.output_index()); +} + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h index 651d75d0d02bdac110159996498778d2c57ddf78..365f7b0e7b19a495ade13a7cff4140cdae68cad2 100644 --- a/tensorflow/compiler/aot/tfcompile_util.h +++ b/tensorflow/compiler/aot/tfcompile_util.h @@ -16,7 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ #define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ +#include + #include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,6 +34,23 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg); // ValidateConfig returns OK iff config is valid. Status ValidateConfig(const Config& config); +// Modifies to include placeholders for each fed tensor, and +// update references to the fed tensors to refer to the placeholders. +// The existing nodes referenced by the feeds are not removed or modified +// (except where their input edges are modified by the replacement of other +// feeds). 
+Status AddPlaceholdersForFeeds( + const Config& config, const OpRegistryInterface* op_registry, + std::unordered_map* feed_remapping, GraphDef* graph_def); + +// Returns in <out> a copy of <in>, pruned to only include fetches from +// <config>. +Status PruneGraphDefInto(const Config& config, const GraphDef& in, + GraphDef* out); + +// Returns node:port for the given <id>. +string TensorIdToString(const TensorId& id); + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc index c321d3ff4c779fbd2e9c67dfc1eb24c734a9103f..5a92851ceb972ca63a8a3845eb4730fe198762dd 100644 --- a/tensorflow/compiler/aot/tfcompile_util_test.cc +++ b/tensorflow/compiler/aot/tfcompile_util_test.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -180,6 +182,65 @@ TEST(ValidateConfig, ConflictingFetchName) { ExpectErrorContains(ValidateConfig(config), "conflicting fetch name"); } +static Config FetchesConfig(std::vector fetches) { + Config config; + for (const auto& fetch_node_name : fetches) { + auto* fetch = config.add_fetch(); + fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); + fetch->mutable_id()->set_node_name(fetch_node_name); + } + return config; +} + +TEST(PruneGraphDefInto, Basic) { + GraphDef def; + auto* n = def.add_node(); + n->set_name("a"); + n->add_input("b:0"); + n->add_input("^c"); + + GraphDef copy; + ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, &copy), + "node missing needed"); + ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy), + "node b needed"); + + n = def.add_node();
+ n->set_name("b"); + ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy), + "node c needed"); + n->add_input("d:1"); + + n = def.add_node(); + n->set_name("c"); + n->add_input("d:1"); + + n = def.add_node(); + n->set_name("d"); + + // Graph is full, no pruning done. + // Graph right now has diamond from d: + // d --> b --> a + // d --> c --> a + TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy)); + EXPECT_EQ(def.DebugString(), copy.DebugString()); + GraphDef pruned_a = copy; + + // Add an unrelated node that uses b and d, but is not needed for a. + n = def.add_node(); + n->set_name("e"); + n->add_input("^d"); + n->add_input("b:2"); + copy.Clear(); + TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy)); + EXPECT_EQ(pruned_a.DebugString(), copy.DebugString()); + + // Fetch "a" and "e" to get the original graph. + copy.Clear(); + TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, &copy)); + EXPECT_EQ(def.DebugString(), copy.DebugString()); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5f857191da78ddd68c5689f9c4f467c01300ca7c..625eb08f1b5a334b0b5b44324c27cab93772a177 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -15,27 +15,16 @@ package_group( ) package( - default_visibility = [":internal"], + default_visibility = [ + ":internal", + "//tensorflow/compiler/plugin/executor:__pkg__", + ], ) load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -# This target can be used by XLA device plugins to prevent circular -# dependencies, and provides access to all of the required headers -# for building a device library.
-cc_header_only_library( - name = "xla_jit_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_cpu_device", - ":xla_cpu_jit", - ":xla_gpu_device", - ":xla_gpu_jit", - ], -) - # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( name = "jit", @@ -150,6 +139,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:tensorflow_opensource", "//tensorflow/core/kernels:constant_op", @@ -243,18 +233,38 @@ cc_library( hdrs = ["union_find.h"], ) +cc_test( + name = "graph_to_functiondef_test", + size = "small", + srcs = [ + "graph_to_functiondef_test.cc", + ], + deps = [ + ":graph_to_functiondef", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_test( name = "compilation_passes_test", size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", - "graph_to_functiondef_test.cc", "mark_for_compilation_pass_test.cc", ], deps = [ ":common", ":compilation_passes", - ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -283,3 +293,15 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. 
+cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 14d8f2ab351bd99dd3fe42a9ac6e31062d552ff0..a1ddad3e9b8191ee4d783136d2b509ec15d993d1 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index 5cdbebd88ee458e5ff332c7a3fe5d736af112ca9..6fa21fa6204dcc9446081d07e2a59ccace216713 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -151,8 +151,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, argdef->set_type(type); const string normalized = node_names.Normalize(node->name()); argdef->set_name(normalized); - CHECK_EQ(node->in_edges().size(), 1) << node->DebugString(); - Edge const* edge = *node->in_edges().begin(); + Edge const* edge; + TF_CHECK_OK(node->input_edge(0, &edge)); return_values[normalized] = strings::StrCat(edge->src()->name(), ":", edge->src_output()); continue; diff --git a/tensorflow/compiler/jit/graph_to_functiondef_test.cc b/tensorflow/compiler/jit/graph_to_functiondef_test.cc index 5c09e96a4c2817e5a871a91ca6c68de87dc3b762..676db7c4dd2fd7047e8ae9bb190daf18af6ac7cf 100644 --- 
a/tensorflow/compiler/jit/graph_to_functiondef_test.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef_test.cc @@ -82,5 +82,38 @@ TEST(GraphToFunctionDefTest, Basics) { EXPECT_TRUE(fdefs_equal) << diff; } +// Regression test for a crash if there was a control edge to a _Retval node. +TEST(GraphToFunctionDefTest, ControlDependencies) { + Scope root = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(root.WithOpName("a"), DT_FLOAT, 0); + auto b = ops::Neg(root.WithOpName("b").WithControlDependencies(a), a); + auto c = ops::_Retval(root.WithOpName("c").WithControlDependencies(b), b, 0); + + GraphDef graph_def; + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphConstructorOptions options; + TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, graph.get())); + + FunctionDef fdef; + TF_EXPECT_OK(GraphToFunctionDef(*graph, "test_fn", &fdef)); + + FunctionDef fdef_expected = FunctionDefHelper::Create( + "test_fn", // function name + {"a: float"}, // inputs + {"c: float"}, // outputs + {}, // attrs + { + // nodes in the function body + {{"b"}, "Neg", {"a", "^a"}, {{"T", DT_FLOAT}}}, + }, + {{"c", "b:y:0"}}); // return values + + string diff; + bool fdefs_equal = EqualFunctionDef(fdef_expected, fdef, &diff); + EXPECT_TRUE(fdefs_equal) << diff; +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index c4116cb8b52adc191e9f695bc9a6e0cf413b4b5c..354c0fabfc78bcb9f5d63e84edc224fc33650ea9 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -2,6 +2,7 @@ licenses(["notice"]) # Apache 2.0 package( default_visibility = [ + "//tensorflow/compiler/plugin/executor:__pkg__", "//tensorflow/compiler/tf2xla:internal", ], ) @@ -35,9 +36,11 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", + 
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc index c86e03118b53ddf4865b7995b1d197c3ef07ba29..bd4eefbc0bb960f8ddc1d238057e73a29a098f26 100644 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ b/tensorflow/compiler/jit/kernels/parallel_check_op.cc @@ -64,7 +64,7 @@ class ParallelCheckOp : public OpKernel { ok = (diff <= tolerance); } if (ok) continue; - LOG(ERROR) << "Op " << def().name() << " fails equality at output " + LOG(ERROR) << "Op " << name() << " fails equality at output " << input_idx << " type " << DataTypeString(dtype) << " element " << i << ": std_val=" << p0[i] << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); @@ -75,7 +75,7 @@ class ParallelCheckOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << def().name(); + VLOG(1) << "Compute " << name(); const int num_pairs = ctx->num_inputs() / 2; for (int i = 0; i < num_pairs; ++i) { CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); @@ -113,7 +113,7 @@ class ParallelCheckOp : public OpKernel { LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); } if (failed > 0) { - LOG(ERROR) << "check failed for " << def().name() << " output " << i + LOG(ERROR) << "check failed for " << name() << " output " << i << " num_elts: " << num_elts; legacy_flags::ParallelCheckOpFlags* flags = legacy_flags::GetParallelCheckOpFlags(); @@ -121,7 +121,7 @@ class ParallelCheckOp : public OpKernel { LOG(QFATAL) << "failfast on first parallel-check failure"; } } else { - VLOG(1) << "check passed for " << def().name() << " output " << i + VLOG(1) << "check passed 
for " << name() << " output " << i << " num_elts: " << num_elts; } diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc index 29c5ff724299ec84d31268c4227259ec02d10742..2b77e5aaf4e0983354c14a4e20656af0e0e4f84b 100644 --- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/allocator.h" @@ -149,6 +151,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutionOptions execution_options; *execution_options.mutable_shape_with_output_layout() = kernel->xla_output_shape; + *execution_options.mutable_debug_options() = + xla::legacy_flags::GetDebugOptionsFromFlags(); Env* env = Env::Default(); auto start_time = env->NowMicros(); VLOG(1) << "Executing XLA Computation..."; @@ -202,11 +206,14 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { // Apply variable updates, if any. 
VLOG(2) << "Applying variable updates"; - for (int i = 0; i < kernel->variable_updates.size(); ++i) { - const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i]; + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); + TensorShape write_shape; + OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape)); + // This code is very close to being a clone of AssignVariableOp, but the // key difference is that the contents of an XLA device tensor cannot be // copied safely; instead we must use @@ -214,26 +221,27 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not // a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - PersistentTensor unused; - Tensor* tmp; - TF_RETURN_IF_ERROR(ctx->allocate_persistent( - write.type, write.shape, &unused, &tmp)); - *(*ptr)->tensor() = *tmp; - return Status::OK(); - })); + OP_REQUIRES_OK(ctx, + LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), &variable, + [this, ctx, &write, &write_shape](Var** ptr) { + *ptr = new Var(write.type); + PersistentTensor unused; + Tensor* tmp; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + write.type, write_shape, &unused, &tmp)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + })); core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, errors::Internal("Mismatched type in variable write")); - if (!variable->tensor()->shape().IsSameSize(write.shape)) { + if (!variable->tensor()->shape().IsSameSize(write_shape)) { PersistentTensor unused; 
Tensor* tmp; - OP_REQUIRES_OK(ctx, ctx->allocate_persistent(write.type, write.shape, + OP_REQUIRES_OK(ctx, ctx->allocate_persistent(write.type, write_shape, &unused, &tmp)); *variable->tensor() = *tmp; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index f1fef85f994a5f1f7514a5cb8b8b339706c7d998..77b45aa11e2e71f206bea4fbf08ed686ec6bb649 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" @@ -162,10 +163,12 @@ Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { return Status::OK(); } -// Does `node` have a DT_RESOURCE typed argument? -bool HasResourceArgument(const Node& node) { +// Tests whether `node` has a DT_RESOURCE typed input or output. 
+bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end(); + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); } Status FindCompilationCandidates( @@ -193,9 +196,10 @@ Status FindCompilationCandidates( << ": " << node->type_string(); continue; } - if (!registration->compile_resource_ops && HasResourceArgument(*node)) { - VLOG(2) << "Compilation rejected node: resource argument " << node->name() - << ": " << node->type_string(); + if (!registration->compile_resource_ops && + HasResourceInputOrOutput(*node)) { + VLOG(2) << "Compilation rejected node: resource input/output " + << node->name() << ": " << node->type_string(); continue; } if (node->type_string() == "While" && @@ -253,6 +257,11 @@ Status MarkForCompilationPass::Run( ®istration)) { return false; } + + // Don't compile control trigger nodes. We won't preserve their deadness + // semantics correctly, so it's safest not to compile them. + if (node->IsControlTrigger()) return false; + // If this device requires a JIT, we must say yes. if (registration->requires_compilation) return true; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9f30e12e0e30fef6b4bcd0ea3c091842b008c29a..4b88da27a188ed4fa6125b3e7a84034efb1a0ec1 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -455,5 +457,39 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } +REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float"); +REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource"); + +namespace { + +class DummyOp : public XlaOpKernel { + using XlaOpKernel::XlaOpKernel; + void Compile(XlaOpKernelContext* ctx) override {} +}; + +REGISTER_XLA_OP(Name("ResourceInput"), DummyOp); +REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp); + +} // namespace + +TEST(XlaCompilationTest, Resources) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + // We should not form clusters with resource ops by default. + Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); + Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); + ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. 
+} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 8d1fa03cc0d74f3a61b3e2e1d6f2af07c0bcd23f..e5787ca4c8cff436e4404b8488970248b24a5eda 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,32 +1,20 @@ licenses(["notice"]) # Apache 2.0 package( - default_visibility = [ - "//tensorflow/compiler/tf2xla:internal", - ], + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) cc_library( name = "xla_ops", - srcs = [ - "xla_ops.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + srcs = ["xla_ops.cc"], + deps = ["//tensorflow/core:framework"], alwayslink = 1, ) cc_library( name = "parallel_check_op", srcs = ["parallel_check_op.cc"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + deps = ["//tensorflow/core:framework"], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 63ca77f9a912acce2078f3da43d64f2e10049380..3c52316ccef0023472b2e888e0c31b07fc00e694 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -148,7 +148,8 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input.dtype(); - arg.shape = input.shape(); + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); arg.constant_value = input; ++input_num; } @@ -169,7 +170,8 @@ Status BuildArguments(int num_constant_args, arg.constant_value = input; } arg.type = input.dtype(); - arg.shape = input.shape(); + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); ++input_num; } @@ -182,19 +184,21 @@ Status BuildArguments(int 
num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.name = variable_args[variable_id].name; + arg.kind = XlaCompiler::Argument::kVariable; if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; - arg.kind = XlaCompiler::Argument::kVariable; arg.type = value.dtype(); - arg.shape = value.shape(); + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape)); + arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since // they are meaningless. However, it is legal to assign to a resource // variable for the first time inside the XLA computation, so we do permit // uninitialized variables. - arg.kind = XlaCompiler::Argument::kUninitializedVariable; + arg.initialized = false; arg.type = DT_INVALID; - arg.shape = TensorShape(); + arg.shape = xla::Shape(); } ++input_num; } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 5e336c5287bd9e2067e93cd8db8a5a1b62b62bd2..615e2230f42f63f893ad645e1ab9513d6c30abf5 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -31,9 +31,11 @@ limitations under the License. 
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index f329e83e14dfce68eff3feb720c1603bd36fa7d6..0ab81ebd5ffec0b3dd6aee509a6d4d2b41d156db 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -137,7 +137,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(result.status()); return; } - const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie()); + const void* src_ptr = result.ValueOrDie()->InternalData(); void* dst_ptr = DMAHelper::base(cpu_tensor); size_t total_bytes = cpu_tensor->TotalBytes(); memcpy(dst_ptr, src_ptr, total_bytes); diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD index 9bc706abdf646a32da734906cada727d949eee21..bc7c25c12056332a8b74077d9f73ea551e8bbbee 100644 --- a/tensorflow/compiler/plugin/executor/BUILD +++ b/tensorflow/compiler/plugin/executor/BUILD @@ -11,12 +11,14 @@ cc_library( "*.h", ]), deps = [ + "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_jit_headers_lib", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:xla_headers_lib", - "//tensorflow/compiler/xla/service:hlo_evaluator", + "//tensorflow/compiler/xla/service", "//third_party/eigen3", "@local_config_cuda//cuda:cuda_headers", - "@protobuf//:protobuf_headers", + 
"@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc index 893ff152f0c77c354be178818eaf9e8fc75feaa4..72fe7ba4519833e17314f8fef803ad0230713780 100644 --- a/tensorflow/compiler/plugin/executor/compiler.cc +++ b/tensorflow/compiler/plugin/executor/compiler.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/plugin/executor/compiler.h" #include "tensorflow/compiler/plugin/executor/executable.h" - #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -30,27 +29,23 @@ limitations under the License. #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/status_macros.h" - +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/strcat.h" -#include "tensorflow/core/lib/core/errors.h" +namespace xla { +namespace executorplugin { namespace se = ::perftools::gputools; namespace sep = ::perftools::gputools::executorplugin; -namespace port = ::perftools::gputools::port; - -namespace xla { -namespace executorplugin { /* * Run optimization passes on the module. The graph is transformed by * each pass in the optimization pipeline. The service subdirectory * contains useful optimization passes. 
*/ -Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module, - HloDumper dump_hlo) { - HloPassPipeline pipeline("Executor", dump_hlo); +Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) { + HloPassPipeline pipeline("Executor"); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(false); @@ -67,13 +62,13 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module, } StatusOr> ExecutorCompiler::Compile( - std::unique_ptr hlo_module, HloDumper dump_hlo, + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Generate graph " << hlo_module->name(); - TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get(), dump_hlo)); + TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); // Typically you would visit the HLO graph, building up a compiled equivalent // In this case we are using an Hlo evaluator at execution time, so we don't @@ -88,7 +83,7 @@ StatusOr> ExecutorCompiler::Compile( StatusOr>> ExecutorCompiler::Compile( std::vector> hlo_modules, - HloDumper dump_hlos, std::vector stream_execs) { + std::vector stream_execs) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Executor."); @@ -97,7 +92,7 @@ StatusOr>> ExecutorCompiler::Compile( StatusOr>> ExecutorCompiler::CompileAheadOfTime( std::vector> hlo_modules, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { + const AotCompilationOptions& aot_options) { return tensorflow::errors::InvalidArgument( "AOT compilation not supported on Executor"); @@ -112,12 +107,11 @@ ExecutorCompiler::ShapeSizeBytesFunction() const { return ExecutorExecutable::ShapeSizeBytes; } - -} // namespace executorplugin -} // namespace xla - REGISTER_MODULE_INITIALIZER(executor_compiler, { xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() { return xla::MakeUnique(); }); }); + +} // namespace executorplugin +} // namespace xla diff --git 
a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h index 8fe591c8abd57933aafa6c82159b49aad45a42d5..d318eefc49f0f1983cf58802d56e71b799944b11 100644 --- a/tensorflow/compiler/plugin/executor/compiler.h +++ b/tensorflow/compiler/plugin/executor/compiler.h @@ -35,25 +35,23 @@ class ExecutorCompiler : public Compiler { StatusOr> Compile( std::unique_ptr hlo_module, - HloDumper dump_hlo, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( std::vector> hlo_module, - HloDumper dump_hlo, std::vector stream_exec) override; StatusOr>> CompileAheadOfTime( std::vector> module, - HloDumper dump_hlo, const AotCompilationOptions& options) override; + const AotCompilationOptions& options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; perftools::gputools::Platform::Id PlatformId() const override; private: - Status RunHloOptimization(HloModule* hlo_module, HloDumper dump_hlo); + Status RunHloOptimization(HloModule* hlo_module); TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler); }; diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc index bbc39dc03f866c0b10c0e4ac46eddebda4bec87f..d902f9df6a50161dacf12a5b234c1304ead353d5 100644 --- a/tensorflow/compiler/plugin/executor/device.cc +++ b/tensorflow/compiler/plugin/executor/device.cc @@ -47,7 +47,12 @@ Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options, return Status::OK(); } -REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110); +// Set priority to be below the default priority (50), so that Executor is not +// selected as a high priority device over other default devices. +// See constructor comments for Registrar in +// tensorflow/core/common_runtime/device_factory.h for a list of priority for +// devices. 
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 40); // Kernel registrations diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc index 92a517ba533cb073dac9b37179825d089e29f3ab..4f1f0d99f9730443f64bc58c16453b195b388ca1 100644 --- a/tensorflow/compiler/plugin/executor/executable.cc +++ b/tensorflow/compiler/plugin/executor/executable.cc @@ -15,18 +15,16 @@ limitations under the License. #include "tensorflow/compiler/plugin/executor/executable.h" #include "tensorflow/compiler/plugin/executor/executor.h" - -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" - #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/shape_util.h" -namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::executorplugin; - namespace xla { namespace executorplugin { +namespace se = ::perftools::gputools; +namespace sep = ::perftools::gputools::executorplugin; + ExecutorExecutable::ExecutorExecutable(std::unique_ptr hlo_module) : Executable(std::move(hlo_module), ShapeSizeBytes) {} @@ -36,7 +34,7 @@ static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor const Literal& literal) { int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); void* buf = executor->Allocate(size); - const void* src = LiteralUtil::InternalData(literal); + const void* src = literal.InternalData(); memcpy(buf, src, size); return se::DeviceMemoryBase(buf, size); } @@ -49,13 +47,14 @@ static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor } else { int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); void** buf = reinterpret_cast(executor->Allocate(size)); + void** buf_rc = buf; for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { se::DeviceMemoryBase out = AllocateSingleOutput(executor, literal.tuple_literals(n)); *buf++ = out.opaque(); } - return 
se::DeviceMemoryBase(buf, size); + return se::DeviceMemoryBase(buf_rc, size); } } @@ -86,19 +85,18 @@ StatusOr ExecutorExecutable::ExecuteOnStream( for (int64 p = 0; p < computation->num_parameters(); p++) { // Create the input literal for the parameter HloInstruction* param = computation->parameter_instruction(p); - arg_literals.emplace_back(LiteralUtil::CreateFromShape(param->shape())); + arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); arg_literals_ptrs.push_back(arg_literals.back().get()); // Copy in the data from the stream_executor buffers - void* buffer = LiteralUtil::MutableInternalData(arg_literals.back().get()); + void* buffer = arg_literals.back()->MutableInternalData(); memcpy(buffer, arguments[p].opaque(), ShapeUtil::ByteSizeOf(param->shape())); } // Execute the graph using the evaluator HloEvaluator evaluator; - std::unique_ptr output; - TF_ASSIGN_OR_RETURN(output, + TF_ASSIGN_OR_RETURN(std::unique_ptr output, evaluator.Evaluate(computation, arg_literals_ptrs)); // Copy the result into the return buffer diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/plugin/executor/executor.cc index e72c2711f794792fd4d7834b07eee5d983dff0a0..908b996bc95ac8d36f6c5577857b1a3a3826c3d4 100644 --- a/tensorflow/compiler/plugin/executor/executor.cc +++ b/tensorflow/compiler/plugin/executor/executor.cc @@ -14,14 +14,12 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/plugin/executor/executor.h" -#include "tensorflow/compiler/plugin/executor/platform_id.h" - -#include "tensorflow/compiler/xla/status_macros.h" #include #include -namespace se = ::perftools::gputools; +#include "tensorflow/compiler/plugin/executor/platform_id.h" +#include "tensorflow/compiler/xla/status_macros.h" namespace perftools { namespace gputools { @@ -37,10 +35,7 @@ ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config) ExecutorExecutor::~ExecutorExecutor() {} -void *ExecutorExecutor::Allocate(uint64 size) { - void *buf = new char[size]; - return buf; -} +void *ExecutorExecutor::Allocate(uint64 size) { return new char[size]; } void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes, @@ -126,8 +121,7 @@ DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const { builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); builder.set_clock_rate_ghz(static_cast(CLOCKS_PER_SEC) / 1e9); - auto built = builder.Build(); - return built.release(); + return builder.Build().release(); } } // namespace executorplugin diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc index b59d20a7791f1ed2df2f35c6186e34e64fe4b248..51c5deeea5d5fd03d0fb99d4f33413c7bf4abe0f 100644 --- a/tensorflow/compiler/plugin/executor/transfer_manager.cc +++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc @@ -70,13 +70,13 @@ Status ExecutorTransferManager::TransferLiteralFromDevice( } *literal->mutable_shape() = device_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); + literal->Reserve(ShapeUtil::ElementsIn(device_shape)); TF_RETURN_IF_ERROR(TransferBufferFromDevice( executor, source, ShapeUtil::ByteSizeOf(device_shape), - LiteralUtil::MutableInternalData(literal))); + literal->MutableInternalData())); if 
(!ShapeUtil::Equal(literal_shape, device_shape)) { literal->Swap( - LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + literal->Relayout(literal_shape.layout()).get()); } TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); return Status::OK(); @@ -134,7 +134,7 @@ Status ExecutorTransferManager::TransferLiteralToDevice( } return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), - LiteralUtil::InternalData(literal), + literal.InternalData(), destination); } @@ -147,6 +147,11 @@ Status ExecutorTransferManager::TransferLiteralToInfeed( return Status::OK(); } +Status ExecutorTransferManager::TransferBufferToInfeed( + se::StreamExecutor* executor, int64 size, const void* source) { + return Unimplemented("Transfer to Infeed"); +} + Status ExecutorTransferManager::TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) { diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h index 22142cd778a0aeccb6c393bdc1593e6213de858a..7a42e5a2d7542eaad7f8f90f011c65a9c526cc11 100644 --- a/tensorflow/compiler/plugin/executor/transfer_manager.h +++ b/tensorflow/compiler/plugin/executor/transfer_manager.h @@ -55,6 +55,9 @@ class ExecutorTransferManager : public TransferManager { Status TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, const void* source) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4bbb2767ac033dd9995cad37886d476fc87618da..c693f58f8bddb7703d10e41afb4b666d92c25823 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -40,7 +40,9 @@ py_library( "//tensorflow/python:client_testlib", 
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -174,6 +176,11 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], + # TODO(b/62962492): Test fails with assertion error. + tags = [ + "manual", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -323,7 +330,7 @@ tf_xla_py_test( tf_xla_py_test( name = "reverse_ops_test", - size = "small", + size = "medium", srcs = ["reverse_ops_test.py"], deps = [ ":xla_test", @@ -346,6 +353,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "segment_reduction_ops_test", + size = "small", + srcs = ["segment_reduction_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "spacetobatch_op_test", size = "medium", @@ -360,6 +381,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stack_ops_test", + size = "small", + srcs = ["stack_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "tensor_array_ops_test", size = "small", @@ -455,6 +489,11 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", ], + # TODO(b/62961789): Test fails with SIGABRT + tags = [ + "manual", + "notap", + ], ) cc_library( diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index a5c5885b4284aee167ae4cb18f7e42820c6d251d..9a93b3216404d8ed21fd6c57757bec1730c119b4 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -49,9 +49,11 @@ 
class AdagradOptimizerTest(XLATestCase): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) + np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + float_rtol=1e-5) def testTensorLearningRate(self): for dtype in self.float_types: @@ -73,9 +75,11 @@ class AdagradOptimizerTest(XLATestCase): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) + np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + float_rtol=1e-5) def testSharing(self): for dtype in self.float_types: @@ -107,9 +111,11 @@ class AdagradOptimizerTest(XLATestCase): ada_update1.run() # Validate updated params (the same as with only 1 Adagrad). 
self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) + np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + float_rtol=1e-5) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 7221a0a3c745f939b88cae0f66af2421922dcd68..0bdbf53c39f0bf35943646d9f11a11bbcfa2d6fe 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -113,6 +113,14 @@ class BinaryOpsTest(XLATestCase): np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) + self._testBinary( + gen_nn_ops._selu_grad, + np.array([1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype), + expected=np.array( + [1.158099340847, 2.7161986816948, 4.67429802254, + 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype)) + self._testBinary( gen_nn_ops._relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), @@ -555,17 +563,18 @@ class BinaryOpsTest(XLATestCase): self._testBinary( math_ops.matmul, np.array( - [[[[1000, 100], [10, 1]], [[2000, 200], [20, 2]]], - [[[3000, 300], [30, 3]], [[4000, 400], [40, 4]]]], + [[[[7, 13], [10, 1]], [[2, 0.25], [20, 2]]], + [[[3, 5], [30, 3]], [[0.75, 1], [40, 4]]]], dtype=np.float32), np.array( [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]], [[55, 66], [77, 88]]]], dtype=np.float32), expected=np.array( - [[[[1300, 2400], [13, 24]], [[11400, 13600], [114, 136]]], - [[[42900, 79200], [429, 792]], [[250800, 299200], [2508, 2992]]]], + [[[[46, 66], [13, 24]], [[11.75, 14], [114, 136]]], + [[[198, 286], [429, 792]], [[118.25, 137.5], [2508, 2992]]]], dtype=np.float32)) + self._testBinary( 
math_ops.matmul, np.array([], dtype=np.float32).reshape((2, 2, 0)), @@ -581,7 +590,7 @@ class BinaryOpsTest(XLATestCase): # Regression test for b/31472796. if hasattr(np, "matmul"): - x = np.arange(0, 3 * 5 * 16 * 7, dtype=np.float32).reshape((3, 5, 16, 7)) + x = np.arange(0, 3 * 5 * 2 * 7, dtype=np.float32).reshape((3, 5, 2, 7)) self._testBinary( lambda x, y: math_ops.matmul(x, y, adjoint_b=True), x, x, diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 6b328fb618bf8b9174dce756487494994b8aea04..7e3871312c86530b6d3cb0bbacc16c25d3469832 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -134,9 +134,9 @@ class FtrlOptimizerTest(XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.60260963, -4.29698515]), var0.eval()) + np.array([-2.60260963, -4.29698515]), var0.eval(), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([-0.28432083, -0.56694895]), var1.eval()) + np.array([-0.28432083, -0.56694895]), var1.eval(), float_rtol=1e-5) def testFtrlwithoutRegularization2(self): for dtype in self.float_types: @@ -189,8 +189,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval()) - self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval()) + self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval(), + rtol=1e-4) + self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval(), + rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -215,10 +217,47 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval()) - self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval()) + self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval(), + rtol=1e-5) + 
self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval(), + rtol=1e-5) + + def testFtrlWithL1_L2_L2Shrinkage(self): + """Test the new FTRL op with support for l2 shrinkage. + + The addition of this parameter which places a constant pressure on weights + towards the origin causes the gradient descent trajectory to differ. The + weights will tend to have smaller magnitudes with this parameter set. + """ + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0, + l2_shrinkage_regularization_strength=0.1) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + ftrl_update.run() + + # Validate updated params + self.assertAllClose(np.array([-0.21931979, -0.40642974]), var0.eval(), + rtol=1e-4) + self.assertAllClose(np.array([-0.0282721, -0.07188385]), var1.eval(), + rtol=1e-4) - # When variables are intialized with Zero, FTRL-Proximal has two properties: + # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical # with GradientDescent. # 2. 
Without L1&L2 but with adaptive learning rate, FTRL-Proximal is idential @@ -233,8 +272,8 @@ class FtrlOptimizerTest(XLATestCase): with self.test_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) - self.assertAllClose(val0, val2) - self.assertAllClose(val1, val3) + self.assertAllClose(val0, val2, rtol=1e-4) + self.assertAllClose(val1, val3, rtol=1e-4) def testEquivGradientDescentwithoutRegularization(self): steps = 5 @@ -245,8 +284,8 @@ class FtrlOptimizerTest(XLATestCase): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) - self.assertAllClose(val0, val2) - self.assertAllClose(val1, val3) + self.assertAllClose(val0, val2, rtol=1e-5) + self.assertAllClose(val1, val3, rtol=1e-5) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 52290e63548910309a9b6b75b7a4642ebeed1efa..7c19a99c4eb4be3ca34b3ce949216e557b0a681d 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -376,7 +376,7 @@ class PoolGradTest(XLATestCase): self.assertAllClose( expected_input_gradient_vals.flatten(), actual.flatten(), - rtol=1e-5, + rtol=1e-4, atol=1e-6) self.assertShapeEqual(actual, inputs) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index d3821ad02e5aa9cbec6ba7fb940ee2246d38c81e..825fd9de2eb306234da36c691e0c7ca2e724dd5a 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -1434,6 +1434,23 @@ TEST_F(OpTest, EluGrad) { }); } +TEST_F(OpTest, Selu) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SeluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad") + .RandomInput(DT_FLOAT, 
dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..260a04421b62310c109d8f0ea72875a50c234bb0 --- /dev/null +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test cases for segment reduction ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class SegmentReductionOpsTest(XLATestCase): + """Test cases for segment reduction ops.""" + + def UnsortedSegmentSum(self, data, indices, num_segments): + with self.test_session() as sess, self.test_scope(): + d = array_ops.placeholder(data.dtype, shape=data.shape) + if isinstance(indices, int): + i = array_ops.placeholder(np.int32, shape=[]) + else: + i = array_ops.placeholder(indices.dtype, shape=indices.shape) + return sess.run( + math_ops.unsorted_segment_sum(d, i, num_segments), + {d: data, + i: indices}) + + def testUnsortedSegmentSum0DIndices1DData(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array( + [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], + [0, 0, 0, 0, 0, 0]], + dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) + + def testUnsortedSegmentSum1DIndices1DData(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([1, 3, 2, 9], dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) + + def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): + for dtype in self.numeric_types: + data = np.array( + [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], + [50, 51, 52, 53]], + dtype=dtype) + indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) + num_segments = 10 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[30, 31, 32, 33], [20, 
21, 22, 23], [0, 0, 0, 0], + [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], + [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self): + for dtype in self.numeric_types: + data = np.array( + [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], + [50, 51, 52, 53]], + dtype=dtype) + indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) + num_segments = 4 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], + [0, 0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum2DIndices3DData(self): + for dtype in self.numeric_types: + data = np.array( + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], + [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], + [310, 311, 312]]], + dtype=dtype) + indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) + num_segments = 8 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[210, 211, 212], [110, 111, 112], [310, 311, 312], + [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, + 302], [0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum1DIndices3DData(self): + for dtype in self.numeric_types: + data = np.array( + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], + [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], + [310, 311, 312]]], + dtype=dtype) + indices = np.array([3, 0, 2, 5], dtype=np.int32) + num_segments = 6 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], + [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]], + [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]], + dtype=dtype), y) + + def testUnsortedSegmentSumShapeError(self): + for dtype in self.numeric_types: + data = np.ones((4, 8, 7), 
dtype=dtype) + indices = np.ones((3, 2), dtype=np.int32) + num_segments = 4 + self.assertRaises(ValueError, + functools.partial(self.UnsortedSegmentSum, data, + indices, num_segments)) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index 9c3b86c84b2b92089da0dfc0070a4a7b8a03c81a..c013f4b50a4cf95be8028248c52b10b1c3be2bd3 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -228,34 +228,40 @@ class SpaceToBatchNDTest(XLATestCase): outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]], [[4, 41], [6, 61]]]) - def testDirect(self): + def testDirect0(self): # Test with zero-size remaining dimension. self._testDirect( input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]]) + def testDirect1(self): # Test with zero-size blocked dimension. self._testDirect( input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]]) + def testDirect2(self): # Test with padding up from zero size. 
self._testDirect( input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]]) + def testDirect3(self): self._testDirect( input_shape=[3, 3, 4, 5, 2], block_shape=[3, 4, 2], paddings=[[1, 2], [0, 0], [3, 0]]) + def testDirect4(self): self._testDirect( input_shape=[3, 3, 4, 5, 2], block_shape=[3, 4, 2, 2], paddings=[[1, 2], [0, 0], [3, 0], [0, 0]]) + def testDirect5(self): self._testDirect( input_shape=[3, 2, 2, 3, 4, 5, 2, 5], block_shape=[1, 1, 3, 4, 2, 2], paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]]) + def testDirect6(self): self._testDirect( input_shape=[3, 2, 2, 3, 4, 5, 2, 5], block_shape=[1, 1, 3, 4, 2, 2, 1], diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9c2279737ccee531d488d27ccdb0cafa1dc8fc --- /dev/null +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.stack_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.platform import test + + +class StackOpTest(XLATestCase): + + def testStackPushPop(self): + with self.test_session(), self.test_scope(): + size = array_ops.placeholder(dtypes.int32) + v = array_ops.placeholder(dtypes.float32) + h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push_v2(h, v) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) + + def testStackPushPopSwap(self): + with self.test_session(), self.test_scope(): + a = np.arange(2000) + x = array_ops.placeholder(dtypes.float32) + h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + self.assertAllClose(a, c1.eval({x: a})) + + def testMultiStack(self): + with self.test_session(), self.test_scope(): + v = array_ops.placeholder(dtypes.float32) + h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_push_v2(h1, v) + with ops.control_dependencies([c1]): + c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar") + c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + with ops.control_dependencies([c2]): + c2 = 
gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + r = c1 + c2 + self.assertAllClose(9.0, r.eval({v: 4.0})) + + def testSameNameStacks(self): + """Different stacks with the same name do not interfere.""" + with self.test_session() as sess, self.test_scope(): + v1 = array_ops.placeholder(dtypes.float32) + v2 = array_ops.placeholder(dtypes.float32) + h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + + c1 = gen_data_flow_ops._stack_push_v2(h1, v1) + with ops.control_dependencies([c1]): + c2 = gen_data_flow_ops._stack_push_v2(h2, v2) + with ops.control_dependencies([c2]): + pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + + out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0}) + self.assertAllClose(out1, 4.0) + self.assertAllClose(out2, 5.0) + + def testCloseStack(self): + with self.test_session() as sess, self.test_scope(): + size = array_ops.placeholder(dtypes.int32) + h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_close_v2(h) + sess.run(c1, {size: 5}) + + def testPushCloseStack(self): + with self.test_session() as sess, self.test_scope(): + v = array_ops.placeholder(dtypes.float32) + h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push_v2(h, v) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_close_v2(h) + sess.run(c1, {v: [[4.0, 5.0]]}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 27a29773053e08c755afce23c3257d96ce27a929..ac039e01623b954e291760fb9b50ef8eae3da7c1 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -57,11 +57,13 @@ class TensorArrayTest(xla_test.XLATestCase): r0 
= w2.read(0) r1 = w2.read(1) r2 = w2.read(2) + flow = w2.flow - d0, d1, d2 = session.run([r0, r1, r2]) + d0, d1, d2, flow_val = session.run([r0, r1, r2, flow]) self.assertAllEqual([[4.0, 5.0]], d0) self.assertAllEqual([[1.0, 3.0]], d1) self.assertAllEqual([[7.0, -8.5]], d2) + self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): with self.test_session(), self.test_scope(): @@ -139,7 +141,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) - # Unpack a matrix into vectors + # Unpack a matrix into vectors. w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])) r0 = w1.read(0) r1 = w1.read(1) @@ -180,7 +182,7 @@ class TensorArrayTest(xla_test.XLATestCase): convert = _make_converter(tf_dtype) - # Split an empty vector + # Split an empty vector. lengths = constant_op.constant([0, 0, 0]) w0 = ta.split(convert([]), lengths=lengths) r0 = w0.read(0) @@ -192,7 +194,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(convert([]), d1) self.assertAllEqual(convert([]), d2) - # Split a vector + # Split a vector. ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) lengths = constant_op.constant([1, 1, 1]) @@ -206,7 +208,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(convert([2.0]), d1) self.assertAllEqual(convert([3.0]), d2) - # Split a matrix + # Split a matrix. ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) lengths = constant_op.constant([1, 1, 1]) @@ -319,27 +321,31 @@ class TensorArrayTest(xla_test.XLATestCase): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) - # Test writing the wrong datatype + # Test writing the wrong datatype. 
with self.assertRaisesOpError( "TensorArray dtype is float but op has dtype int32"): ta.write(-1, np.int32(7)).flow.eval() def testTensorArrayReadWrongIndexOrDataTypeFails(self): - with self.test_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) - - w0 = ta.write(0, [[4.0, 5.0]]) - - # Test reading wrong datatype - r0_bad = gen_data_flow_ops._tensor_array_read_v3( - handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) - with self.assertRaisesOpError( - "TensorArray dtype is float but Op requested dtype double."): - r0_bad.eval() - - # Test reading from a different index than the one we wrote to - w0.read(1) + # Find two different floating point types, create an array of + # the first type, but try to read the other type. + if len(self.float_types) > 1: + dtype1 = self.float_types[0] + dtype2 = self.float_types[1] + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtype1, tensor_array_name="foo", size=3) + + w0 = ta.write(0, [[4.0, 5.0]]) + + # Test reading wrong datatype. + r0_bad = gen_data_flow_ops._tensor_array_read_v3( + handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) + with self.assertRaisesOpError("TensorArray dtype is "): + r0_bad.eval() + + # Test reading from a different index than the one we wrote to + w0.read(1) def testTensorArraySplitIncompatibleShapesFails(self): with self.test_session(), self.test_scope(): @@ -487,7 +493,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0 = w1.read(0) s0 = w1.concat() - # Test gradient accumulation between read(0), pack(), and concat() + # Test gradient accumulation between read(0), pack(), and concat(). 
with ops.control_dependencies([p0, r0, s0]): grad_r = gradients_impl.gradients( ys=[p0, r0, s0], @@ -536,7 +542,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_1 = w.read(0) r1 = w.read(1) - # Test combined gradients + aggregation of read(0) + # Test combined gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r0_1, r1], xs=[value], @@ -573,13 +579,12 @@ class TensorArrayTest(xla_test.XLATestCase): [2000.0, -2000.0]], grad_vals[0]) - # TODO(phawkins): implement TensorArrayClose - # def testCloseTensorArray(self): - # with self.test_session() as session, self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, tensor_array_name="foo", size=3) - # c1 = ta.close() - # session.run(c1) + def testCloseTensorArray(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = ta.close() + session.run(c1) def testSizeTensorArray(self): with self.test_session(), self.test_scope(): @@ -588,17 +593,16 @@ class TensorArrayTest(xla_test.XLATestCase): s = ta.size() self.assertAllEqual(3, s.eval()) - # TODO(phawkins): implement TensorArrayClose - # def testWriteCloseTensorArray(self): - # with self.test_session(), self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, - # tensor_array_name="foo", - # size=3, - # infer_shape=False) - # w0 = ta.write(0, [[4.0, 5.0]]) - # w1 = w0.write(1, [3.0]) - # w1.close().run() # Expected to run without problems + def testWriteCloseTensorArray(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [3.0]) + w1.close().run() # Expected to run without problems # TODO(phawkins): implement while loops. 
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): @@ -746,7 +750,7 @@ class TensorArrayTest(xla_test.XLATestCase): grad_b_t, = session.run([grad_b]) self.assertAllEqual(grad_b_t, g0) - # Test gradients calculated jointly + # Test gradients calculated jointly. joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b]) self.assertAllEqual(joint_grad_a_t, g0) self.assertAllEqual(joint_grad_b_t, g0) @@ -879,7 +883,7 @@ class TensorArrayTest(xla_test.XLATestCase): x = constant_op.constant([2.0, 3.0]) w = ta.unstack(x) r0 = w.read(0) - # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0). + # Calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0). grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0]) grad_r0_vals = session.run(grad_r0)[0] self.assertAllEqual(grad_r0_vals, [1.0, 0.0]) @@ -929,7 +933,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0 = w.read(1) r1 = w.read(8) - # Test combined gradients + aggregation of read(0) + # Test combined gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) read_vals, grad_vals = session.run([[r0, r1], grad]) @@ -953,7 +957,7 @@ class TensorArrayTest(xla_test.XLATestCase): w = ta.unstack(values) g = w.gather(indices) - # Test combined gradients + aggregation of read(0) + # Test combined gradients + aggregation of read(0). 
grad = gradients_impl.gradients( ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]]) g_vals, grad_vals = session.run([[g], grad]) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 51d8786ce3d7148e6863be7e1557a8bb23153d63..81ff18f3023c17f722632962dfa1cac60a7dfdc1 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -152,6 +152,16 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2]], dtype=dtype), expected=np.array([[0, 0.69314718]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sin, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0.841478, 0.909302]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.cos, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0.540297, -0.41614]], dtype=dtype)) + # TODO(b/34703906): improve log1p implementation and make tolerance # tighter. self._assertOpOutputMatchesExpected( @@ -219,6 +229,11 @@ class UnaryOpsTest(XLATestCase): np.array([[-1, 0, 1]], dtype=dtype), expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.selu, + np.array([[-1, 0, 1]], dtype=dtype), + expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.relu, np.array([[-1, 1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 70dacd9de4b95dfb77986dfaf177c16b758406f1..a6b59fc731e7556cbfa6e0c2c4f889b58568e622 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -54,6 +54,54 @@ class VariableOpsTest(XLATestCase): self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), sess.run(y, {p: 1})) + def testSparseRead0DIndices(self): + for dtype in self.numeric_types: + init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + with 
self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + x = v.sparse_read(2) + self.assertAllClose(np.array([8, 9, 10, 11], dtype=dtype), sess.run(x)) + + def testSparseRead1DIndices(self): + for dtype in self.numeric_types: + init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + x = v.sparse_read([2, 1]) + self.assertAllClose( + np.array([[8, 9, 10, 11], [4, 5, 6, 7]], dtype=dtype), sess.run(x)) + + def testSparseRead2DIndices(self): + for dtype in self.numeric_types: + init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + x = v.sparse_read([[2, 1], [0, 2]]) + self.assertAllClose( + np.array( + [[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2, 3], [8, 9, 10, + 11]]], + dtype=dtype), sess.run(x)) + + def testSparseRead2DIndices3DTensor(self): + for dtype in self.numeric_types: + init = np.array( + [[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], + [[20, 21, 22], [23, 24, 25]], [[30, 31, 32], [33, 34, 35]]], + dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + x = v.sparse_read([[2, 1], [3, 0]]) + self.assertAllClose( + np.array( + [[[[20, 21, 22], [23, 24, 25]], [[10, 11, 12], [13, 14, 15]]], + [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], + dtype=dtype), sess.run(x)) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" with self.test_session() as session: diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 
93c484ca7a0d04654371724aac905eb055c82b05..60e68db2d689f502481c45a748f6e6abac2b69e8 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -42,6 +42,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -101,6 +102,7 @@ cc_test( "//tensorflow/cc:ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -152,7 +154,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", ], ) @@ -165,13 +166,10 @@ cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", ], ) @@ -203,6 +201,59 @@ cc_library( ], ) +cc_library( + name = "functionalize_control_flow", + srcs = ["functionalize_control_flow.cc"], + hdrs = ["functionalize_control_flow.h"], + deps = [ + "//tensorflow/compiler/jit:graph_to_functiondef", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "functionalize_control_flow_test", + srcs = ["functionalize_control_flow_test.cc"], + deps = [ + ":functionalize_control_flow", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + 
"//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/compiler/tf2xla/cc:functional_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..599265ba449c88baef1671b1c81d96d1715ce5f2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -0,0 +1,44 @@ +package( + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") + +tf_gen_op_wrapper_cc( + name = "functional_ops_gen", + include_internal_ops = 1, + out_ops_file = "ops/functional_ops", + deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"], +) + +cc_library( + name = "functional_ops", + srcs = ["ops/functional_ops.cc"], + hdrs = ["ops/functional_ops.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + 
srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 36a6c90af4f4e8e0618ea5a5432365d4d90e51e4..d98cf829bb6819ea2efc3217a9539a88b570bc4b 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -81,6 +81,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Split", "split_dim"}, {"SplitV", "split_dim"}, {"SplitV", "size_splits"}, + {"StackV2", "max_size"}, {"StridedSlice", "begin"}, {"StridedSlice", "end"}, {"StridedSlice", "strides"}, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c7a2046aa549beb2de58d21f517363d4fe8aea7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -0,0 +1,583 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/control_flow.h" + +namespace tensorflow { + +namespace { + +const char* const kArgOp = "_Arg"; +const char* const kRetValOp = "_Retval"; + +// Information about a loop argument. +struct Arg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. + std::unordered_set nodes; +}; + +// Copies a subgraph from `graph` to `output` by performing a reverse DFS +// starting at nodes in vector `stack`. +// `node_map` is a vector indexed by source node ID to dest nodes. +// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` +// before the traversal clients can cut the graph. 
Returns an error if the +// traversal leaves 'frame'; the client must add enough nodes to `node_map` to +// cut the graph and prevent the traversal from escaping. +// +// `squash_src_outputs` contains a bool for each source node ID. If true, then +// the source output on that node will be replaced by zero when copied. This is +// used when replacing a Switch node with an _Arg node. The output we are +// taking from the Switch node was not necessarily the first output, but _Arg +// nodes only have one output. By adding the Switch node to `squash_src_outputs` +// we rewrite the src_output of the corresponding edge to be 0. +Status CopySubgraph(const Graph& graph, const Frame& frame, + std::vector stack, + const std::vector& squash_src_outputs, + std::vector* node_map, Graph* output) { + std::vector visited(graph.num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + VLOG(3) << "Copying node " << n->name(); + + if (visited[n->id()]) continue; + visited[n->id()] = true; + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (frame.nodes.find(src) == frame.nodes.end()) { + // We traversed out of the loop frame, without encountering a cut node. + return errors::Internal("Graph traversal of loop frame ", frame.name, + " escaped frame at ", src->name(), + " without encountering an argument node."); + } + if ((*node_map)[src->id()] == nullptr) { + (*node_map)[src->id()] = output->CopyNode(src); + stack.push_back(src); + } + Node* src_copy = (*node_map)[e->src()->id()]; + int src_output = squash_src_outputs[e->src()->id()] ? 
0 : e->src_output(); + Node* dst_copy = (*node_map)[e->dst()->id()]; + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + return Status::OK(); +} + +Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) { + NodeDef arg_def; + NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp); + builder.Attr("T", type); + builder.Attr("index", index); + TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); + Status status; + *arg_node = graph->AddNode(arg_def, &status); + return status; +} + +Status BuildRetvalNode(Graph* graph, DataType type, int index, + Node** retval_node) { + NodeDef ret_def; + ret_def.set_op(kRetValOp); + ret_def.set_name(strings::StrCat("_Retval", index)); + AddNodeAttr("T", type, &ret_def); + AddNodeAttr("index", index, &ret_def); + Status status; + *retval_node = graph->AddNode(ret_def, &status); + return status; +} + +// Builds a graph for the loop condition. +Status BuildLoopCondition(const Graph& graph, Frame* frame, + std::unique_ptr* cond_output) { + VLOG(2) << "Building loop condition for " << frame->name; + *cond_output = xla::MakeUnique(graph.op_registry()); + Graph* output = cond_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + Node* arg_node; + TF_RETURN_IF_ERROR( + BuildArgNode(output, arg.enter->input_type(0), i, &arg_node)); + if (arg.is_loop_invariant) { + node_map[arg.enter->id()] = arg_node; + } else { + node_map[arg.merge->id()] = arg_node; + } + } + + // Build a Retval node for the loop condition. The LoopCond nodes are always + // boolean because of the type constraints on the LoopCond op. 
+ TF_RETURN_IF_ERROR( + BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()])); + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond}, + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +// Builds a graph for the loop body. +Status BuildLoopBody(const Graph& graph, Frame* frame, + DataTypeVector* arg_types, + std::unique_ptr* body_output) { + VLOG(2) << "Building loop body for " << frame->name; + *body_output = xla::MakeUnique(graph.op_registry()); + Graph* output = body_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + std::vector next_iterations; + next_iterations.reserve(frame->args.size()); + arg_types->reserve(frame->args.size()); + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + DataType dtype = arg.enter->input_type(0); + arg_types->push_back(dtype); + Node* arg_node; + TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node)); + + if (dtype == DT_RESOURCE) { + // The convention of the XLA bridge is that resource variable arguments + // are only inputs to the loop body and have no corresponding output. + // TODO(b/37741920): change the convention so that DT_RESOURCE variables + // are both inputs and outputs, and then remove this case. + TF_RET_CHECK(arg.is_loop_invariant); + node_map[arg.enter->id()] = arg_node; + } else { + Node* retval_node; + TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node)); + + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. 
+ node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); + } else { + // Argument is loop-varying. + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); + } + } + } + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations), + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +Status FunctionalizeLoop(Graph* graph, Frame* frame, + FunctionLibraryDefinition* library) { + VLOG(2) << "Frame " << frame->name << " before: " + << dump_graph::DumpGraphToFile("functionalize_before", *graph); + + // Split loop-varying Enter nodes with multiple successors. If the same + // Tensor is fed as input to multiple loop arguments, we may end up with a + // shared Enter node. We clone Enter nodes with multiple successors to + // maintain the invariant of a unique Enter node per argument of the final + // loop. 
+ std::vector args; + for (const Arg& arg : frame->args) { + if (arg.is_loop_invariant) { + args.push_back(arg); + } else { + std::vector edges(arg.enter->out_edges().begin(), + arg.enter->out_edges().end()); + for (int i = 0; i < edges.size(); ++i) { + if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { + continue; + } + TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); + Arg new_arg; + new_arg.is_loop_invariant = false; + if (i == 0) { + new_arg.enter = arg.enter; + } else { + new_arg.enter = graph->CopyNode(arg.enter); + frame->nodes.insert(new_arg.enter); + for (Edge const* e : arg.enter->in_edges()) { + graph->AddEdge(e->src(), e->src_output(), new_arg.enter, + e->IsControlEdge() ? Graph::kControlSlot : 0); + } + Node* dst = edges[i]->dst(); + int dst_input = edges[i]->dst_input(); + graph->RemoveEdge(edges[i]); + graph->AddEdge(new_arg.enter, 0, dst, dst_input); + } + args.push_back(new_arg); + } + } + } + frame->args = std::move(args); + + // Order the arguments so that: + // a) resource variables are last, and + // b) sort lexicographically by name (for deterministic output). + std::sort(frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { + bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE); + bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE); + return std::tie(a_is_resource, a.enter->name()) < + std::tie(b_is_resource, b.enter->name()); + }); + + if (frame->loop_cond == nullptr) { + return errors::InvalidArgument("Loop ", frame->name, + " has no LoopCond node"); + } + + // Find the set of Switch nodes that are successors of the LoopCond. 
+ std::unordered_set switches; + for (const Edge* edge : frame->loop_cond->out_edges()) { + if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && + edge->dst_input() == 1) { + switches.insert(edge->dst()); + } + } + + // For each non-constant argument, looks for the following pattern of nodes: + // Enter ----> Merge --------> Switch --> Exit + // ^ ^ + // | | + // NextIteration LoopCond + // ^ ^ + // | | + // ... ... + for (Arg& arg : frame->args) { + if (!arg.is_loop_invariant) { + // Follow the edge from the Enter to Merge. + const Edge* enter_merge = nullptr; + for (const Edge* e : arg.enter->out_edges()) { + // Ignore control-edges to the sink node. These are allowed by the + // graph invariants, although probably they should have been stripped + // off earlier. + if (e->IsControlEdge() && e->dst()->IsSink()) { + continue; + } + if (enter_merge != nullptr) { + return errors::Internal( + "Enter node for loop-varying argument ", arg.enter->name(), + " has multiple successors: ", enter_merge->dst()->name(), " and ", + e->dst()->name()); + } + enter_merge = e; + } + if (enter_merge == nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + arg.enter->name(), " has zero successors"); + } + arg.merge = enter_merge->dst(); + if (!IsMerge(arg.merge)) { + return errors::InvalidArgument( + "Successor of Enter node for loop-varying argument ", + arg.merge->name(), + " is not a Merge node; got: ", arg.merge->type_string()); + } + + // Find the NextIteration from the merge. There should be two inputs to + // the Merge and the NextIteration should be the other input. 
+ if (arg.merge->input_types().size() != 2) { + return errors::InvalidArgument( + "Unexpected number of inputs to Merge node for loop-varying " + "argument ", + arg.merge->name(), "; expected 2, got ", + arg.merge->input_types().size()); + } + TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), + &arg.next_iteration)); + if (!IsNextIteration(arg.next_iteration)) { + return errors::InvalidArgument( + "Expected NextIteration node as input to Merge node; got node ", + arg.next_iteration->name(), " with kind ", + arg.next_iteration->type_string()); + } + + // Find the Switch successor of the Merge. There should be exactly one + // Switch node that is a successor of both the Merge and the LoopCond. + for (const Edge* edge : arg.merge->out_edges()) { + if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && + switches.find(edge->dst()) != switches.end()) { + if (arg.switch_node != nullptr) { + return errors::InvalidArgument("Duplicate Switch successors to ", + arg.merge->name()); + } + arg.switch_node = edge->dst(); + } + } + if (arg.switch_node == nullptr) { + return errors::InvalidArgument("Missing Switch successor to ", + arg.merge->name()); + } + + // Find the Exit successor of the Switch. + for (const Edge* edge : arg.switch_node->out_edges()) { + if (edge->src_output() == 0 && IsExit(edge->dst())) { + if (arg.exit != nullptr) { + return errors::InvalidArgument("Duplicate Exit successors to ", + arg.switch_node->name()); + } + arg.exit = edge->dst(); + } + } + if (arg.exit == nullptr) { + return errors::InvalidArgument("Missing Exit successor to ", + arg.switch_node->name()); + } + } + } + + // Builds the condition and body functions. 
+ std::unique_ptr cond_graph; + TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + DataTypeVector arg_types; + std::unique_ptr body_graph; + TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + + VLOG(2) << "Frame " << frame->name << " condition: " + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph) + << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + NameAttrList cond_name; + cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + NameAttrList body_name; + body_name.set_name(strings::StrCat("_functionalize_body_", id)); + FunctionDef cond_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); + + TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + + // Builds a While operator. + NodeDef while_def; + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + builder.Attr("T", arg_types); + builder.Attr("cond", cond_name); + builder.Attr("body", body_name); + std::vector inputs; + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + inputs.push_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), arg_types[i])); + } + } + builder.Input(inputs); + TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); + + Status status; + Node* while_node = graph->AddNode(while_def, &status); + if (!status.ok()) { + return status; + } + + // Copies edges to the Enter nodes and from the Exit nodes onto the While. 
+ for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph->AddControlEdge(in_edge->src(), while_node); + } else { + graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); + } + + if (!arg.is_loop_invariant) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + int src_output = + dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph->AddEdge(while_node, src_output, dst, dst_input); + } + } + } + + // Remove the old nodes from the graph, and add the while node to the parent + // frame. + for (Node* node : frame->nodes) { + graph->RemoveNode(node); + } + frame->parent->nodes.insert(while_node); + + VLOG(2) << "Frame " << frame->name << " after: " + << dump_graph::DumpGraphToFile("functionalize_after", *graph); + + return Status::OK(); +} + +} // namespace + +// Transformation that converts Tensorflow's graph control flow constructs into +// functional equivalents. +Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(2) << "FunctionalizeControlFlow: " + << dump_graph::DumpGraphToFile("functionalize_initial", *graph); + // Note: BuildControlFlowInfo() requires that the graph's source node is + // connected to all source nodes in the graph. Many graphs violate this + // invariant. + std::vector cf_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + + // Builds Frames, indexed by name. + std::unordered_map frames; + for (Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + + VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name + << " frame: " << (cf.frame ? cf.frame->name() : "---") + << " parent_frame: " + << (cf.parent_frame ? 
cf.parent_frame->name() : "---"); + TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + ++parent->num_children; + } else if (frame.parent != parent) { + return errors::InvalidArgument("Mismatched parent frames for ", + cf.frame->id(), ": ", parent->name, " vs ", + frame.parent->name); + } + + if (IsEnter(node)) { + Arg arg; + arg.enter = node; + TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", + &arg.is_loop_invariant)); + frame.args.push_back(arg); + } else if (IsLoopCond(node)) { + if (frame.loop_cond) { + return errors::InvalidArgument( + "Loop ", cf.frame_name, + " has more than one LoopCond node: ", node->name(), " and ", + frame.loop_cond->name()); + } + frame.loop_cond = node; + } + frame.nodes.insert(node); + } + + // Adds frames with no children (i.e., the innermost frames) to a worklist. + std::deque worklist; + for (auto& frame : frames) { + if (frame.second.num_children == 0) { + worklist.push_back(&frame.second); + } + } + + // Eliminate loops from innermost to outermost. + while (!worklist.empty()) { + Frame* frame = worklist.front(); + worklist.pop_front(); + if (frame->parent == frame) { + // Skip the root frame. + continue; + } + + TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + + // If the parent has no remaining children, add it to the worklist. 
+ --frame->parent->num_children; + if (frame->parent->num_children == 0) { + worklist.push_back(frame->parent); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h new file mode 100644 index 0000000000000000000000000000000000000000..1535dc80b0ccdba38c57b534ed7473fc8632e33f --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Transformation that converts tf.while_loop() loops into functional While +// operators, suitable for XLA compilation. +// TODO(b/36470387): add support for conditionals. 
+Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..914c8999a6f13f5f2dc4e3cecc38c91afd432131 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -0,0 +1,658 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { +namespace { + +// Returns the names of the "cond" and "body" functions for the While node +// in a graph. 
+Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, + NameAttrList* body) { + for (const NodeDef& node : graph.node()) { + if (node.op() == "XlaWhile") { + const NameAttrList* result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); + *cond = *result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); + *body = *result; + return Status::OK(); + } + } + return errors::NotFound("No XlaWhile node found in graph"); +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) +TEST(FunctionalizeControlFlow, OneLoopVar) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + // Add an unused Enter node. These should be ignored. 
+ auto enter2 = + ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + auto sink = ops::Identity(scope.WithOpName("sink"), exit); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + // Regression test: control edges from an Enter node to the graph sink should + // be ignored. 
+ for (Node* n : graph.nodes()) { + if (n->name() == "while/Enter") { + graph.AddControlEdge(n, graph.sink_node()); + } + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = array_ops.placeholder(dtypes.int32) +// cond = lambda (i, j): i + 3 < 10 +// body = lambda (i, j): (i < 10, j * 2) +// z = control_flow_ops.while_loop(cond, body, [x, y]) +TEST(FunctionalizeControlFlow, TwoLoopVars) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto enter_x = + ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop"); + auto enter_y = + ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop"); + auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"), + std::initializer_list{enter_x, dummy}); + auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"), + std::initializer_list{enter_y, dummy}); + + // Loop condition + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(merge_x.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three); 
+ auto ten = ops::Const(scope.WithOpName("while/cond/ten") + .WithControlDependencies(merge_x.output), + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + + auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"), + merge_x.output, loop_cond); + auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"), + merge_y.output, loop_cond); + + auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"), + switch_x.output_false); + auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"), + switch_y.output_false); + + auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), + switch_x.output_true); + auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), + switch_y.output_true); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto next_iteration_x = + ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add); + auto next_iteration_y = + ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul); + + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y); + + // Remove the dummy node and add the loop backedges. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(), + 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const( + scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + + auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +// Example with nesting, loop-invariant arguments, and resource variables. 
+// +// accum = resource_variable_ops.ResourceVariable(1) +// x = array_ops.placeholder(2, dtype=dtypes.int32) +// y = 3 + x +// +// def inner_body(j, k): +// add = state_ops.assign_add(accum, k * j + x) +// with ops.control_dependencies([add]): +// return [j + 1, k] +// +// def body(i): +// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, +// [1, y], name="inner") +// with ops.control_dependencies(m): +// return [i + 1] +// +// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") +TEST(FunctionalizeControlFlow, Complex) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + // Outer loop + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + auto enter_i = + ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer"); + auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"), + std::initializer_list{enter_i, dummy}); + auto ten = ops::Const(scope.WithOpName("outer/Less/y") + .WithControlDependencies(merge_i.output), + 10); + auto less_i = + ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten); + auto outer_loop_cond = + ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i); + auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"), + merge_i.output, outer_loop_cond); + auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"), + switch_i.output_false); + auto identity_i = + ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true); + + auto enter_x_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + auto 
enter_k_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + auto enter_var_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + + // Inner loop + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"), + one_j, "inner"); + auto enter_k = + ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k") + .WithControlDependencies(identity_i), + enter_k_outer, "inner"); + auto enter_x = ops::internal::Enter( + scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner", + ops::internal::Enter::Attrs().IsConstant(true)); + auto enter_var = ops::internal::Enter( + scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner", + ops::internal::Enter::Attrs().IsConstant(true)); + + auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"), + std::initializer_list{enter_j, dummy}); + auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"), + std::initializer_list{enter_k, dummy}); + + auto five = ops::Const(scope.WithOpName("outer/inner/Five") + .WithControlDependencies(merge_j.output), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); + auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j); + + auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"), + merge_j.output, loop_cond); + auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"), + merge_k.output, loop_cond); + auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"), + switch_j.output_false); + auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"), + switch_k.output_false); + auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"), + switch_j.output_true); + auto 
identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"), + switch_k.output_true); + + // Variable update + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); + + auto one = + ops::Const(scope.WithOpName("outer/inner/One") + .WithControlDependencies( + gtl::ArraySlice{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto next_iteration_j = ops::NextIteration( + scope.WithOpName("outer/inner/NextIteration_j"), add_j); + auto next_iteration_k = ops::NextIteration( + scope.WithOpName("outer/inner/NextIteration_k"), identity_k); + + // Body and backedge for outer loop. + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(gtl::ArraySlice{ + exit_j.output.op(), exit_k.output.op()}), + identity_i, one_outer); + auto next_iteration_i = + ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i); + + auto sink = ops::Identity(scope.WithOpName("sink"), exit_i); + + // Remove the dummy node and add the loop backedge. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(), + 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList outer_cond_fn, outer_body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + + auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Outer condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Outer body graph. + NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + + // Find the inner condition and body names. 
+ TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(gtl::ArraySlice{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); + auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = + ops::Const(scope.WithOpName("outer/inner/One") + .WithControlDependencies( + gtl::ArraySlice{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto 
retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index a434c7468095a05ee6da31826d44379a735b51f7..546e9be8647587991de5d0d0c232827ad84fba94 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -47,6 +47,7 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "segment_reduction_ops.cc", "select_op.cc", "sequence_ops.cc", "shape_op.cc", @@ -54,6 +55,7 @@ tf_kernel_library( "softmax_op.cc", "spacetobatch_op.cc", "split_op.cc", + "stack_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", @@ -68,6 +70,7 @@ tf_kernel_library( "reduction_ops.h", ], deps = [ + ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal_util", @@ -91,6 +94,21 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "while_op", + srcs = ["while_op.cc"], + hdrs = ["while_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + # 
Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. # diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 620fc8443785388781caf5121da53c4d908d4cb4..1156546512952871fafe93e4b5a42308322671df 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -51,13 +51,29 @@ class ArgOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); const XlaContext::Argument& arg = xc.args()[index_]; - if (arg.is_variable) { + if (arg.is_resource) { + XlaResource::Kind kind; + switch (arg.kind) { + case XlaCompiler::Argument::kVariable: + kind = XlaResource::kVariable; + break; + case XlaCompiler::Argument::kTensorArray: + kind = XlaResource::kTensorArray; + break; + case XlaCompiler::Argument::kStack: + kind = XlaResource::kStack; + break; + default: + CHECK(false); + } + // TODO(phawkins): this code assumes that variables do not alias. 
- XlaVariable* var; - OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type, - arg.value.handle, &var)); - var->tensor_array_size = arg.tensor_array_size; - ctx->SetVariableOutput(0, var); + XlaResource* resource; + OP_REQUIRES_OK(ctx, + xc.CreateResource(kind, index_, arg.name, arg.value.type, + arg.value.handle, &resource)); + resource->tensor_array_size = arg.tensor_array_size; + ctx->SetResourceOutput(0, resource); } else if (arg.value.is_constant) { ctx->SetConstantOutput(0, arg.value.constant_value); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 8642cbf2a924e3c82c80bff8f5122e62ce12082d..21d3e64872e19109852297838043975cea6d7921 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -127,8 +127,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, std::vector end_indices = reshaped_permuted_shape; std::vector strides(input_rank, 1); for (int i = 0; i < block_rank; ++i) { - int64 crop_start = xla::LiteralUtil::Get(crops, {i, 0}); - int64 crop_end = xla::LiteralUtil::Get(crops, {i, 1}); + int64 crop_start = crops.Get({i, 0}); + int64 crop_end = crops.Get({i, 1}); OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0, errors::InvalidArgument("Crops must be non-negative")); start_indices[1 + i] = crop_start; diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index b0fee5e4bca502a7abb4613b58ecdd2ffca2206d..bc2cd31230dfe9ca35540341d225dcb768fa34f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -55,7 +55,7 @@ class BCastGradArgsOp : public XlaOpKernel { BCast::Vec vec; for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(xla::LiteralUtil::Get(literal, {i})); + vec.push_back(literal.Get({i})); } shapes.push_back(vec); } diff --git 
a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 124e33d7935ce19ced72d1c84521ffda1090bc86..2331520230176fce7646d89140851fe37aee5fda 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -38,17 +38,6 @@ class CastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; - } else if (src_dtype_ == DT_BOOL) { - // XLA's ConvertElementType doesn't support casting to/from - // bools. So we need to handle those cases separately. - // Builds the equivalent of (input ? 1 : 0) - xla::ComputationBuilder l(builder->client(), "PredCast"); - xla::ComputationDataHandle x = - l.Parameter(0, xla::ShapeUtil::MakeShape(src_type_, {}), "x"); - l.Select(x, XlaHelpers::One(&l, dst_dtype_), - XlaHelpers::Zero(&l, dst_dtype_)); - xla::Computation computation = l.Build().ConsumeValueOrDie(); - output = builder->Map({input}, computation); } else if (dst_dtype_ == DT_BOOL) { output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e2eacb3839d39e6fa41192e8aa0f31d878d96aea..73a4740e29af7fa57e71ef42a342f46b0e24231d 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -52,7 +52,7 @@ class ConcatBaseOp : public XlaOpKernel { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); // TODO(annarev): add a helper to support int64 input. 
- const int32 concat_dim = xla::LiteralUtil::Get(literal, {}); + const int32 concat_dim = literal.Get({}); std::vector values; std::vector shapes; @@ -163,7 +163,7 @@ class ConcatOffsetOp : public XlaOpKernel { xla::Literal concat_dim_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = xla::LiteralUtil::Get(concat_dim_literal, {}); + const int64 cdim = concat_dim_literal.Get({}); VLOG(1) << "ConcatOffset " << cdim << "," << dims; int32 axis = cdim < 0 ? cdim + dims : cdim; @@ -185,12 +185,10 @@ class ConcatOffsetOp : public XlaOpKernel { for (int64 j = 0; j < dims; ++j) { if (j == axis) { out_vec(j) = offset; - offset += xla::LiteralUtil::Get(inp_literal, {j}); + offset += inp_literal.Get({j}); } else { - const int32 inp0_element = - xla::LiteralUtil::Get(inp0_literal, {j}); - const int32 inp_element = - xla::LiteralUtil::Get(inp_literal, {j}); + const int32 inp0_element = inp0_literal.Get({j}); + const int32 inp_element = inp_literal.Get({j}); OP_REQUIRES( ctx, (inp0_element == inp_element), errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index ad676e7a2bb3d3f28ecb98164323cbf1e32f61a9..9833323d851e00e7ca76d0b39cd2b216748a17fa 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 107c673f4a7d62f8b760b137aeda2864e156b7f7..dde7898015e73190c96fa6effddfd3fc892264ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -63,11 +63,14 @@ class DynamicStitchOp : public XlaOpKernel { std::vector indices(indices_input.size()); const TensorShape& data0_shape = data_shapes[0]; - const TensorShape indices0_shape = - XLAShapeToTensorShape(indices_input[0].shape()); + TensorShape indices0_shape; + OP_REQUIRES_OK( + ctx, XLAShapeToTensorShape(indices_input[0].shape(), &indices0_shape)); for (int input_num = 0; input_num < indices_input.size(); input_num++) { - const TensorShape indices_shape = - XLAShapeToTensorShape(indices_input[input_num].shape()); + TensorShape indices_shape; + OP_REQUIRES_OK(ctx, + XLAShapeToTensorShape(indices_input[input_num].shape(), + &indices_shape)); const TensorShape& data_shape = data_shapes[input_num]; OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), errors::InvalidArgument( @@ -103,8 +106,7 @@ class DynamicStitchOp : public XlaOpKernel { int max_index = -1; for (int input_num = 0; input_num < indices.size(); input_num++) { for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { - max_index = std::max( - max_index, xla::LiteralUtil::Get(indices[input_num], {i})); + max_index = std::max(max_index, indices[input_num].Get({i})); } } int number_of_indices = max_index + 1; @@ -118,7 +120,7 @@ class DynamicStitchOp : public XlaOpKernel { int index_used_count = 0; for (int input_num = 0; input_num < indices.size(); input_num++) { for (int 
i = 0; i < indices[input_num].shape().dimensions(0); ++i) { - int index = xla::LiteralUtil::Get(indices[input_num], {i}); + int index = indices[input_num].Get({i}); src_input_vector[index] = input_num; src_slice_vector[index] = i; if (!src_index_used[index]) { diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 62a5e1bd421a75fb0a8fa6eacd58e4aaa2f02236..2fd27c5ca7e87c8b387d9d0854b787d30e7f7b6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -61,5 +61,49 @@ class EluGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Elu"), EluOp); REGISTER_XLA_OP(Name("EluGrad"), EluGradOp); +class SeluOp : public XlaOpKernel { + public: + explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes scale * x if x > 0, otherwise scale_alpha * (exp(x) - 1). + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), + 1.0507009873554804934193349852946); + const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), + 1.7580993408473768599402175208123); + const auto pred = b->Gt(ctx->Input(0), zero); + const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), + b->Mul(scale_alpha, expm1))); + } +}; + +class SeluGradOp : public XlaOpKernel { + public: + explicit SeluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Returns scale * lhs (incoming gradient) if the rhs (activation) > 0, + // otherwise returns lhs * (rhs + scale_alpha). 
+ void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), + 1.0507009873554804934193349852946); + const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), + 1.7580993408473768599402175208123); + const auto grad = ctx->Input(0); + const auto activation = ctx->Input(1); + const auto lin_grad = b->Mul(grad, scale); + const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha)); + const auto pred = b->Gt(activation, zero); + ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad)); + } +}; + +REGISTER_XLA_OP(Name("Selu"), SeluOp); +REGISTER_XLA_OP(Name("SeluGrad"), SeluGradOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 1e1d2a1b4b3fa281adc96b76ade5ce7b07b2b41c..9e090fe01cbfd4dab81b0de21e3a44e42c2ef18e 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -52,7 +52,7 @@ class FillOp : public XlaOpKernel { std::vector broadcast; broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(xla::LiteralUtil::Get(dims_literal, {i})); + broadcast.push_back(dims_literal.Get({i})); } // Look up the value input, reshaping to a scalar if it was a // 'legacy' scalar (secretly a vector). diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index 8dacb6627bde516c92cb07b747207adbe85ada5b..af1085d5b35077b7ebd144bfb2473485e3b3de6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 49eadaf9d1f0ff1dbfa2321f20f9f833a0d4eb9a..184b5119f83d35e91d76685701c61fe712ac91ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -29,6 +29,7 @@ class GatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape params_shape = ctx->InputShape(0); + const auto params_dims = params_shape.dims(); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES( ctx, TensorShapeUtils::IsVectorOrHigher(params_shape), @@ -38,20 +39,51 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, errors::InvalidArgument("index must be int32 or int64")); + // GatherV2 added an axis argument. We support both Gather and GatherV2 in + // this kernel by defaulting axis to 0 if there are 2 inputs. + int64 axis = 0; + if (ctx->num_inputs() == 3) { + const TensorShape axis_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(axis_shape), + errors::InvalidArgument("axis must be scalar")); + DataType axis_type = input_type(2); + OP_REQUIRES(ctx, axis_type == DT_INT32 || axis_type == DT_INT64, + errors::InvalidArgument("axis must be int32 or int64")); + + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal)); + int64 axis_input = axis_type == DT_INT32 ? literal.Get({}) + : literal.Get({}); + axis = axis_input < 0 ? 
axis_input + params_dims : axis_input; + OP_REQUIRES(ctx, 0 <= axis && axis < params_dims, + errors::InvalidArgument("Expected axis in the range [", + -params_dims, ", ", params_dims, + "), but got ", axis_input)); + } + // Check that we have enough index space. const int64 limit = index_type == DT_INT32 ? std::numeric_limits::max() : std::numeric_limits::max(); - OP_REQUIRES( - ctx, params_shape.dim_size(0) <= limit, - errors::InvalidArgument("params.shape[0] too large for ", - DataTypeString(index_type), " indexing: ", - params_shape.dim_size(0), " > ", limit)); - - // The result shape is indices.shape + params.shape[1:]. - TensorShape result_shape = indices_shape; - for (int i = 1; i < params_shape.dims(); i++) { + OP_REQUIRES(ctx, params_shape.dim_size(axis) <= limit, + errors::InvalidArgument( + "params.shape[", axis, "] too large for ", + DataTypeString(index_type), + " indexing: ", params_shape.dim_size(axis), " > ", limit)); + + // The result shape is params.shape[0:axis] + indices.shape + + // params.shape[axis + 1:]. 
+ TensorShape result_shape; + int64 outer_size = 1; + int64 inner_size = 1; + for (int i = 0; i < axis; i++) { + result_shape.AddDim(params_shape.dim_size(i)); + outer_size *= params_shape.dim_size(i); + } + result_shape.AppendShape(indices_shape); + for (int i = axis + 1; i < params_dims; i++) { result_shape.AddDim(params_shape.dim_size(i)); + inner_size *= params_shape.dim_size(i); } XlaContext& tc = XlaContext::Get(ctx); @@ -66,11 +98,13 @@ class GatherOp : public XlaOpKernel { std::vector args; args.push_back(tc.GetOrCreateRuntimeContextParameter()); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR0(indices_shape.num_elements()))); + *xla::Literal::CreateR0(indices_shape.num_elements()))); + args.push_back( + b.ConstantLiteral(*xla::Literal::CreateR0(outer_size))); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR0(params_shape.dim_size(0)))); - args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0( - params_shape.num_elements() / params_shape.dim_size(0)))); + *xla::Literal::CreateR0(params_shape.dim_size(axis)))); + args.push_back( + b.ConstantLiteral(*xla::Literal::CreateR0(inner_size))); args.push_back(ctx->Input(0)); args.push_back(ctx->Input(1)); @@ -97,6 +131,10 @@ REGISTER_XLA_OP(Name("Gather") .TypeConstraint("Tparams", DT_FLOAT) .Device(DEVICE_CPU_XLA_JIT), GatherOp); +REGISTER_XLA_OP(Name("GatherV2") + .TypeConstraint("Tparams", DT_FLOAT) + .Device(DEVICE_CPU_XLA_JIT), + GatherOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc index 691a0b972d5c09ad632d706d72a1b60988730986..33b1b087d00d8263cd80f7d5d879401e4ed6c0fb 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -26,28 +26,31 @@ namespace tensorflow { EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, 
void** data) { // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 6 * sizeof(void*)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); int64 indices_size = *static_cast(data[1]); int64 params_x = *static_cast(data[2]); int64 params_y = *static_cast(data[3]); + int64 params_z = *static_cast(data[4]); - float* in = static_cast(data[4]); + float* in = static_cast(data[5]); - int32* indices = static_cast(data[5]); - Eigen::DSizes in_eig_sizes; + int32* indices = static_cast(data[6]); + Eigen::DSizes in_eig_sizes; in_eig_sizes[0] = params_x; in_eig_sizes[1] = params_y; - tensorflow::TTypes::ConstMatrix in_eig(in, in_eig_sizes); + in_eig_sizes[2] = params_z; + tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); Eigen::DSizes indices_eig_sizes; indices_eig_sizes[0] = indices_size; tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = indices_size; - out_eig_sizes[1] = params_y; - tensorflow::TTypes::Matrix out_eig(out, out_eig_sizes); + Eigen::DSizes out_eig_sizes; + out_eig_sizes[0] = params_x; + out_eig_sizes[1] = indices_size; + out_eig_sizes[2] = params_z; + tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); tensorflow::functor::GatherFunctorCPU f; const int64 bad_i = f(in_eig, indices_eig, out_eig); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc index 3dff6e2737bf1af7f5d646928e740fa895692a03..5e2d872ce0b28ab479c73ed1fea5f32804c21e22 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -26,28 +26,31 @@ namespace tensorflow { EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { // data is managed by the JIT code so msan can't tell it's initialized. 
- TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 6 * sizeof(void*)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); int64 indices_size = *static_cast(data[1]); int64 params_x = *static_cast(data[2]); int64 params_y = *static_cast(data[3]); + int64 params_z = *static_cast(data[4]); - float* in = static_cast(data[4]); + float* in = static_cast(data[5]); - int64* indices = static_cast(data[5]); - Eigen::DSizes in_eig_sizes; + int64* indices = static_cast(data[6]); + Eigen::DSizes in_eig_sizes; in_eig_sizes[0] = params_x; in_eig_sizes[1] = params_y; - tensorflow::TTypes::ConstMatrix in_eig(in, in_eig_sizes); + in_eig_sizes[2] = params_z; + tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); Eigen::DSizes indices_eig_sizes; indices_eig_sizes[0] = indices_size; tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = indices_size; - out_eig_sizes[1] = params_y; - tensorflow::TTypes::Matrix out_eig(out, out_eig_sizes); + Eigen::DSizes out_eig_sizes; + out_eig_sizes[0] = params_x; + out_eig_sizes[1] = indices_size; + out_eig_sizes[2] = params_z; + tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); tensorflow::functor::GatherFunctorCPU f; const int64 bad_i = f(in_eig, indices_eig, out_eig); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index df002dddd043c6795481436586a31c74b20d33d1..6be66cf66ec19cad33858f36a3239048efce9de3 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -69,7 +69,7 @@ class ArgMaxOp : public XlaOpKernel { // XLA op would have the same requirement. 
xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = xla::LiteralUtil::Get(literal, {}); + const int32 dim = literal.Get({}); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES( ctx, dim < input_shape.dims(), @@ -97,14 +97,13 @@ class ArgMaxOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + *xla::Literal::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); - args.push_back( - b.ConstantLiteral(*xla::LiteralUtil::CreateR0(dim))); + *xla::Literal::CreateR1(output_shape.dim_sizes()))); + args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index b8f0c0b9fe6087a7719a689628ca4738cc13aab9..8c8a9bbe787f3224e7444b62dcf8ad99130cf37f 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -23,4 +23,9 @@ namespace tensorflow { // dummy operator using CompilationOnly(). REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); +// We register ControlTrigger as a no-op. This is correct since nodes seen +// by the XLA compiler are never dead. This may need rethinking when we add +// support for conditionals to XLA. 
+REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 22476f4a0c51930cabf146313347e5e3bd2eaebe..d841bd37b33c31dbc156fa824ff62a58169a99cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -60,8 +60,8 @@ class PadOp : public XlaOpKernel { xla::PaddingConfig config; for (int i = 0; i < fixed_dims; ++i) { auto* dim = config.add_dimensions(); - int before = xla::LiteralUtil::Get(pad_literal, {i, 0}); - int after = xla::LiteralUtil::Get(pad_literal, {i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument("Paddings must be non-negative: ", before, " ", after)); @@ -69,12 +69,22 @@ class PadOp : public XlaOpKernel { dim->set_edge_padding_high(after); } - auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + // PadV2 added a "constant_values" input that indicates the pad value. 
+ xla::ComputationDataHandle constant_values; + if (ctx->num_inputs() == 3) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), + errors::InvalidArgument("constant_values must be a scalar.")); + ctx->SetOutput(0, + ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config)); + } else { + auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); + ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + } } }; REGISTER_XLA_OP(Name("Pad"), PadOp); +REGISTER_XLA_OP(Name("PadV2"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 518a9372c4fa3f195ff7c77e8ef0de1ba0a8807b..dae2eb9d2a92ef8d4eabb8d6f9a79758c42d446d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -63,7 +63,7 @@ class MinOp : public XlaReductionOp { xla::ComputationBuilder* builder) override { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + return builder->ConstantLiteral(xla::Literal::MaxValue(type)); } void BuildReducer(xla::ComputationBuilder* builder, @@ -83,7 +83,7 @@ class MaxOp : public XlaReductionOp { xla::ComputationBuilder* builder) override { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + return builder->ConstantLiteral(xla::Literal::MinValue(type)); } void BuildReducer(xla::ComputationBuilder* builder, diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 8798c80ad5354c76a9b4061ad8913b76ae0629b0..4b5d09eb9fd4110cdc4221099ff55767e9132540 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -66,13 +66,13 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { 1, {axes_tensor_shape.num_elements()}, &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal); + VLOG(1) << "axes : " << axes_literal.ToString(); gtl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { - int32 index = xla::LiteralUtil::Get(axes_literal, {i}); + int32 index = axes_literal.Get({i}); OP_REQUIRES(ctx, !(index < -data_shape.dims() || index >= data_shape.dims()), errors::InvalidArgument("Invalid reduction dimension (", index, diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index df542350b443b765a1ab35be9632cf61a38be49c..5952e752724d1e6953dd4dbb6a8099b847c64d08 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -50,7 +50,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = xla::LiteralUtil::Get(literal, {d}); + const int32 size = literal.Get({d}); if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a0ce775dc69e1b87041bad31b13cdaff676e20f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace { + +class UnsortedSegmentSum : public XlaOpKernel { + public: + explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // output = unsorted_segment_sum(data, indices, num_segments) + // Compute a tensor such that: + // output[i] = sum over {j where indices[j] == i} of data[j] + // output[i] == 0 if i does not appear in indices + // + // Contrast with segment_sum(), which assumes indices are sorted and that + // max(indices)+1 is the desired size of the output. + // + // The returned output tensor has the same type as data, and the same shape + // as data with the first indices.rank dimensions are replaced + // by a single dimension with size num_segments. 
+ + xla::ComputationBuilder* builder = ctx->builder(); + + auto data = ctx->Input(0); + auto data_shape = ctx->InputShape(0); + + auto indices = ctx->Input(1); + auto indices_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), + errors::InvalidArgument( + "UnsortedSegmentSum requires that indices' rank be" + " less than or equal to data's rank.")); + // Validate that indices.shape is a prefix of data.shape. + for (int d = 0; d < indices_shape.dims(); ++d) { + OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), + errors::InvalidArgument( + "UnsortedSegmentSum requires indices shape to be prefix" + " of data_shape, but dimension ", + d, " differs ", data_shape.dim_size(d), " vs. ", + indices_shape.dim_size(d))); + } + + int64 num_segments; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); + + // Flatten the indices into 1-D. + auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()}); + + // flatten data for dynamic indexing. + int64 out_tensor_dims = data_shape.dims() - indices_shape.dims(); + std::vector flat_shape(1 + out_tensor_dims); + flat_shape[0] = indices_shape.num_elements(); + for (int64 k = 0; k < out_tensor_dims; ++k) { + flat_shape[1 + k] = data_shape.dim_size(indices_shape.dims() + k); + } + auto data_flat = builder->Reshape(data, flat_shape); + + // output shape; same as data_shape, but dimension 0 is num_segments. + std::vector out_shape(flat_shape); + out_shape[0] = num_segments; + + // Pad the output array dims to rank >= 3 to work around lowering issues. + // TODO(b/37575001) This is awkward, and could be improved. 
+ int64 extra_dims = 0; + if (out_shape.size() < 3) { + extra_dims = 3u - out_shape.size(); + } + std::vector rshape(extra_dims + out_shape.size(), 1); + for (unsigned k = 0; k < out_shape.size(); ++k) { + rshape[extra_dims + k] = out_shape[k]; + } + auto output = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), rshape); + + auto zero = builder->ConstantR1({0}); + + for (int64 i = 0; i < indices_shape.num_elements(); ++i) { + // output[indices[i]] += data[i] + + std::vector data_start_indices(flat_shape.size()); + data_start_indices[0] = i; + for (unsigned d = 1; d < flat_shape.size(); ++d) { + data_start_indices[d] = 0; + } + std::vector data_limit_indices(flat_shape); + data_limit_indices[0] = i + 1; + std::vector stride(flat_shape.size(), 1); + + auto data_slice = builder->Slice(data_flat, data_start_indices, + data_limit_indices, stride); + + // Reshape the sliced data into the R3+ shape to match output array. + std::vector rdata_shape(extra_dims + flat_shape.size()); + for (int64 k = 0; k <= extra_dims; ++k) { + rdata_shape[k] = 1; + } + for (unsigned k = 1; k < data_limit_indices.size(); ++k) { + rdata_shape[extra_dims + k] = data_limit_indices[k]; + } + auto rdata_slice = builder->Reshape(data_slice, rdata_shape); + + auto index = builder->Slice(indices_1d, {i}, {i + 1}, {1}); + + // Construct the index into the R3+ output array 0, ..., , 0, ... 
+ std::vector out_start_index_parts( + extra_dims + flat_shape.size(), zero); + out_start_index_parts[extra_dims] = builder->Reshape(index, {1}); + auto out_start_indices = builder->ConcatInDim(out_start_index_parts, 0); + + std::vector slice_size(rshape); + slice_size[extra_dims] = 1; + + auto out_slice = + builder->DynamicSlice(output, out_start_indices, slice_size); + auto sumval = builder->Add(out_slice, rdata_slice); + output = builder->DynamicUpdateSlice(output, sumval, out_start_indices); + } + auto reshaped_output = builder->Reshape(output, out_shape); + ctx->SetOutput(0, reshaped_output); + } + + private: + DataType dtype_; +}; + +REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 5b6fa64fa825894b5d7bf938c5892d30f4fc11b0..c2b0e1bb4c1a141d0ab3f5b3ff5397d9da620bd8 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -32,7 +32,7 @@ template Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { xla::Literal literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); return Status::OK(); } @@ -41,10 +41,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); switch (literal.shape().element_type()) { case xla::S32: - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); break; case xla::S64: - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); break; default: return errors::InvalidArgument("Invalid argument type for argument", @@ -58,9 +58,9 @@ template Status CreateRangeTensor(const xla::Literal& start_literal, const xla::Literal& limit_literal, const xla::Literal& delta_literal, Tensor* output) { - T start 
= xla::LiteralUtil::Get(start_literal, {}); - T limit = xla::LiteralUtil::Get(limit_literal, {}); - T delta = xla::LiteralUtil::Get(delta_literal, {}); + T start = start_literal.Get({}); + T limit = limit_literal.Get({}); + T delta = delta_literal.Get({}); if (delta == 0) { return errors::InvalidArgument("Requires delta != 0: ", delta); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index f15b354cb26d390352d866a8e827970f7c8b0c7f..83a87f19a718ce86a105e3c33ab9eaf0faff3a76 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -56,8 +56,8 @@ void SpaceToBatch(XlaOpKernelContext* ctx, padding_config.add_dimensions(); // Don't pad the batch dimension. for (int i = 0; i < block_rank; ++i) { auto* dim = padding_config.add_dimensions(); - int64 pad_start = xla::LiteralUtil::Get(paddings, {i, 0}); - int64 pad_end = xla::LiteralUtil::Get(paddings, {i, 1}); + int64 pad_start = paddings.Get({i, 0}); + int64 pad_end = paddings.Get({i, 1}); OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0, errors::InvalidArgument("Paddings must be non-negative")); dim->set_edge_padding_low(pad_start); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 42bde90042218b3a36f50e32d4f458d31c82d5da..44ee81461e5b31f15594c0dfb86f7219f9875768 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -39,7 +39,7 @@ class SplitOp : public XlaOpKernel { int32 split_dim; if (index_shape.dims() == 0) { - split_dim = xla::LiteralUtil::Get(literal_index, {}); + split_dim = literal_index.Get({}); } else { OP_REQUIRES( ctx, index_shape.dims() == 1, @@ -49,7 +49,7 @@ class SplitOp : public XlaOpKernel { ctx, index_shape.dim_size(0) == 1, errors::InvalidArgument("split_index input to Split Op must be a " "scalar or a vector with 1 element")); - 
split_dim = xla::LiteralUtil::Get(literal_index, {0}); + split_dim = literal_index.Get({0}); } const int32 num_split = num_outputs(); const TensorShape input_shape = ctx->InputShape(1); @@ -115,7 +115,7 @@ class SplitVOp : public XlaOpKernel { OP_REQUIRES(ctx, index_shape.dims() == 0, errors::InvalidArgument("split_dim input to Split Op must be a " "scalar")); - split_dim = xla::LiteralUtil::Get(literal_index, {}); + split_dim = literal_index.Get({}); xla::ComputationDataHandle input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); @@ -152,7 +152,7 @@ class SplitVOp : public XlaOpKernel { for (int i = 0; i < num_split; ++i) { int slice_size; - slice_size = xla::LiteralUtil::Get(split_size_literal, {i}); + slice_size = split_size_literal.Get({i}); if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..5d1394c280383b7e9b9be39da4ed028e15a005fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -0,0 +1,250 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA Stack operators. 
+ +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, + TensorShape* stack_shape) { + auto shape_or_status = builder->GetShape(resource->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + xla::Shape shape = *shape_or_status.ValueOrDie(); + TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), + stack_shape); +} + +// Since the element shape is not provided to the Stack operator, +// we lazily initialize the Stack at the time of the first write. +// +// If a Stack `resource` has not been initialized, constructs storage for the +// Stack with elements of `elem_shape`. For both initialized and +// uninitialized Stacks, checks that the tensor has a type compatible with +// 'dtype' and shape compatible with 'elem_shape'. +// +// TODO(phawkins): consider changing the API of the stack operators to +// allow an optional element shape at stack construction time. 
+Status MaybeInitializeStack(xla::ComputationBuilder* builder, + XlaResource* resource, DataType dtype, + const TensorShape& elem_shape) { + if (resource->type != dtype) { + return errors::InvalidArgument( + "Stack dtype is ", DataTypeString(resource->type), " but op has dtype ", + DataTypeString(dtype), "."); + } + + TensorShape stack_shape; + stack_shape.AddDim(resource->tensor_array_size); + stack_shape.AppendShape(elem_shape); + + if (resource->value.handle() == 0) { + // Stack has not been initialized. + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); + resource->value = + builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()), + builder->ConstantR0(0)}); + } else { + // Checks the expected shape matches the actual shape. + TensorShape actual_shape; + TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &actual_shape)); + if (stack_shape != actual_shape) { + return errors::InvalidArgument( + "Mismatched Stack shapes: ", stack_shape.DebugString(), " vs ", + actual_shape.DebugString()); + } + } + return Status::OK(); +} + +class StackOp : public XlaOpKernel { + public: + explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + int64 size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); + OP_REQUIRES( + ctx, size >= 0, + errors::InvalidArgument( + "XLA compilation requires a fixed stack size upper bound.")); + + // We defer initializing the Stack resource until we see the first push. + // Otherwise we do not know the shape of the stack elements. 
+ xla::ComputationDataHandle value; + XlaContext& xc = XlaContext::Get(ctx); + XlaResource* resource; + string name = strings::StrCat("Stack: ", stack_name_); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, + value, &resource)); + resource->tensor_array_size = size; + ctx->SetResourceOutput(0, resource); + } + + private: + DataType dtype_; + string stack_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(StackOp); +}; + +REGISTER_XLA_OP(Name("StackV2"), StackOp); + +class StackPushOp : public XlaOpKernel { + public: + explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + TensorShape elem_shape = ctx->InputShape(1); + + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + // Initializes the Stack, if the element shape was not already known. + OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); + + xla::ComputationDataHandle ta = b->GetTupleElement(resource->value, 0); + xla::ComputationDataHandle index = b->GetTupleElement(resource->value, 1); + xla::ComputationDataHandle value = ctx->Input(1); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims()); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = b->Reshape(value, slice_shape.dim_sizes()); + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. 
+ resource->value = + b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), + b->Add(index, b->ConstantR0(1))}); + + ctx->SetOutput(0, value); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); +}; + +REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); + +class StackPopOp : public XlaOpKernel { + public: + explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES(ctx, resource->type == dtype_, + errors::InvalidArgument( + "Stack dtype is ", DataTypeString(resource->type), + " but Op requested dtype ", DataTypeString(dtype_), ".")); + + // There is a somewhat subtle issue here: here "uninitialized" means we have + // not yet seen a pop in the order that we compile operators, not the order + // that we run them. However, in practice the two orders should be the same + // for the sole user of the stack operators (loop gradients). + OP_REQUIRES(ctx, resource->value.handle() != 0, + errors::InvalidArgument("Stack pop on uninitialized stack")); + + TensorShape stack_shape; + OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); + + xla::ComputationDataHandle state = resource->value; + xla::ComputationDataHandle ta = b->GetTupleElement(state, 0); + xla::ComputationDataHandle index = b->GetTupleElement(state, 1); + + index = b->Sub(index, b->ConstantR0(1)); + resource->value = b->Tuple({ta, index}); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = + XlaHelpers::PadWithZeros(b, index, stack_shape.dims() - 1); + + auto slice_shape = stack_shape.dim_sizes(); + slice_shape[0] = 1LL; + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. 
+ xla::ComputationDataHandle read = + b->DynamicSlice(ta, start_indices, slice_shape); + + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + ctx->SetOutput(0, b->Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); +}; + +REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); + +class StackCloseOp : public XlaOpKernel { + public: + explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Do nothing. + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); +}; + +REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 9eb689983105eff05555bbe454f97149eb8f14a2..6af4bd0496e0da926726e3f74376281f539e925a 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -63,17 +63,13 @@ class StridedSliceOp : public XlaOpKernel { &strides_tensor)); TensorShape dummy_processing_shape; - ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); - ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( - &dummy_processing_shape); bool dummy = false; - OP_REQUIRES_OK( - ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - ShapeReadWriteFromTensorShape(&input_shape), begin_mask_, - end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, - &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK(ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, input_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &dummy_processing_shape, &final_shape, + &dummy, &dummy, &dummy, &begin, &end, 
&strides)); gtl::InlinedVector dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_end, slice_strides; @@ -146,14 +142,11 @@ class StridedSliceGradOp : public XlaOpKernel { &strides_tensor)); bool dummy = false; - ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); - ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape); OP_REQUIRES_OK( ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - ShapeReadWriteFromTensorShape(&input_shape), begin_mask_, - end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &wrapped_processing_shape, &wrapped_final_shape, &dummy, + &begin_tensor, &end_tensor, strides_tensor, input_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); // Check to make sure dy is consistent with the original slice @@ -257,17 +250,13 @@ class StridedSliceAssignOp : public XlaOpKernel { const TensorShape rhs_shape = ctx->InputShape(4); TensorShape dummy_processing_shape; - ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); - ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( - &dummy_processing_shape); bool dummy = false; - OP_REQUIRES_OK( - ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_, - end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, - &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK(ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, lhs_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &dummy_processing_shape, &final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides)); if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { // DynamicUpdateSlice does not allow 0-element updates. 
We should probably diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index deee7dd44dbf80f83ded3f09819365f7b6c1c7bd..34cc8b23159a0c20166c28d21911d4f3e7a43693 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -41,36 +41,42 @@ namespace { // Since the element shape is not always provided to the TensorArrayV3 operator, // we must support lazily initialization of the TensorArray at the time of the // first write. -// If a TensorArray `var` has not been initialized, constructs storage for the -// TensorArray with elements of `elem_shape`. For both initialized and +// If a TensorArray `resource` has not been initialized, constructs storage for +// the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. 
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, - XlaVariable* var, DataType dtype, + XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (var->type != dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument("Unexpected non-TensorArray resource"); + } + + if (resource->type != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(var->type), + "TensorArray dtype is ", DataTypeString(resource->type), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(var->tensor_array_size >= 0) - << var->name << " size " << var->tensor_array_size; + TF_RET_CHECK(resource->tensor_array_size >= 0) + << resource->name << " size " << resource->tensor_array_size; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - if (var->value.handle() == 0) { + if (resource->value.handle() == 0) { // TensorArray has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type); - var->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); + resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); } else { // Checks the elem_shape matches the TensorArray shape. 
- auto shape_or_status = builder->GetShape(var->value); + auto shape_or_status = builder->GetShape(resource->value); if (!shape_or_status.ok()) { return shape_or_status.status(); } - TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + TensorShape shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); if (ta_shape != shape) { return errors::InvalidArgument( "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", @@ -80,14 +86,43 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, return Status::OK(); } -// Pads 'x' with 'count' zero indices. 'x' must have 1 element. -xla::ComputationDataHandle PadIndexWithZeros( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - int count) { - xla::ComputationDataHandle zero = builder->ConstantR1({0}); - std::vector xs(count + 1, zero); - xs[0] = builder->Reshape(x, {1}); - return builder->ConcatInDim(xs, 0); +// Checks that the TensorArray 'resource' has been initialized, and has type +// 'dtype'. 
Sets 'shape' to the shape +Status CheckTensorArrayIsInitialized(const string& op_name, + const XlaResource* resource, + DataType dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument( + "Unexpected non-TensorArray resource passed " + "to ", + op_name); + } + if (resource->value.handle() == 0) { + return errors::InvalidArgument("Uninitialized TensorArray passed to ", + op_name); + } + if (resource->type != dtype) { + return errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(resource->type), + " but op has dtype ", DataTypeString(dtype), "."); + } + + return Status::OK(); +} + +Status GetTensorArrayShape(const XlaResource* resource, + xla::ComputationBuilder* builder, + TensorShape* shape) { + auto shape_or_status = builder->GetShape(resource->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + if (shape->dims() < 1) { + return errors::InvalidArgument("TensorArray rank must be >= 1"); + } + return Status::OK(); } // Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the @@ -125,7 +160,6 @@ class TensorArrayOp : public XlaOpKernel { errors::InvalidArgument("TensorArray size must be >= 0")); xla::ComputationBuilder* b = ctx->builder(); - b->set_die_immediately_on_error(true); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. 
@@ -141,13 +175,17 @@ class TensorArrayOp : public XlaOpKernel { } XlaContext& xc = XlaContext::Get(ctx); - XlaVariable* var; + XlaResource* var; string name = strings::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK(ctx, - xc.CreateVariable(-1, std::move(name), dtype_, value, &var)); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + dtype_, value, &var)); var->tensor_array_size = size; - ctx->SetVariableOutput(0, var); - ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); + ctx->SetResourceOutput(0, var); + + Tensor flow(DT_FLOAT, TensorShape({})); + flow.scalar()() = 0.0f; + ctx->SetConstantOutput(1, flow); } private: @@ -173,16 +211,18 @@ class TensorArrayWriteOp : public XlaOpKernel { // Initializes the TensorArray, if the element shape was not known at // construction time. - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); + xla::ComputationDataHandle flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims()); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -191,8 +231,8 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written)); - ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + resource->value = written; + ctx->SetOutput(0, flow); } private: @@ -210,24 +250,22 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(ta_type), - " but Op requested dtype ", DataTypeString(dtype_), ".")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + auto start_indices = + XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -255,24 +293,23 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; + xla::ComputationBuilder* b = ctx->builder(); + + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); const TensorShape indices_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, indices_shape.dims() >= 1, + OP_REQUIRES(ctx, indices_shape.dims() == 1, errors::InvalidArgument("indices must be rank 1")); const int num_indices = indices_shape.dim_size(0); auto indices = ctx->Input(1); - xla::ComputationBuilder* b = ctx->builder(); - - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + xla::ComputationDataHandle ta = resource->value; // For each index in `indices`, add the corresponding slice to `slices`. std::vector slices(num_indices); @@ -282,7 +319,8 @@ class TensorArrayGatherOp : public XlaOpKernel { auto index = b->Slice(indices, {i}, {i + 1}, {1}); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + auto start_indices = + XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -320,11 +358,12 @@ class TensorArrayScatterOp : public XlaOpKernel { const TensorShape value_shape = ctx->InputShape(2); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); TensorShape elem_shape = value_shape; elem_shape.RemoveDim(0); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -332,8 +371,9 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; const xla::ComputationDataHandle value = ctx->Input(2); + const xla::ComputationDataHandle flow = ctx->Input(3); auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -353,12 +393,13 @@ class TensorArrayScatterOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + auto start_indices = + XlaHelpers::PadWithZeros(b, index, elem_shape.dims()); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); - ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + resource->value = ta; + ctx->SetOutput(0, flow); } private: @@ -376,18 +417,17 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -438,31 +478,32 @@ class TensorArraySplitOp : public XlaOpKernel { elem_shape.set_dim(0, length); xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); + xla::ComputationDataHandle ta = resource->value; TensorShape 
ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size, + OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, errors::InvalidArgument( "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", var->tensor_array_size, ")")); + lengths.size(), " vs. ", resource->tensor_array_size, ")")); const xla::ComputationDataHandle value = ctx->Input(1); + const xla::ComputationDataHandle flow = ctx->Input(3); OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), errors::InvalidArgument("mismatched element count ", value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + ctx->SetOutput(0, flow); } private: @@ -478,8 +519,8 @@ class TensorArraySizeOp : public XlaOpKernel { explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* var; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); size_tensor.scalar()() = static_cast(var->tensor_array_size); ctx->SetConstantOutput(0, size_tensor); @@ -500,31 +541,31 @@ class TensorArrayGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - DataType ta_type; + OP_REQUIRES_OK( + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); TensorShape ta_shape; - 
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. - XlaVariable*& gradient = var->tensor_array_gradient[source_]; + XlaResource*& gradient = resource->tensor_array_gradient[source_]; if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); xla::ComputationDataHandle value = b->Broadcast(zero, ta_shape.dim_sizes()); XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", var->name); - OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, - value, &gradient)); - gradient->tensor_array_size = var->tensor_array_size; + string name = strings::StrCat("TensorArrayGrad: ", resource->name); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + resource->type, value, &gradient)); + gradient->tensor_array_size = resource->tensor_array_size; } - ctx->SetVariableOutput(0, gradient); + ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } @@ -536,5 +577,19 @@ class TensorArrayGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); +class TensorArrayCloseOp : public XlaOpKernel { + public: + explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Do nothing; XLA handles resource management. 
+ } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 4cc2eb8f877a873593f0460346e3379e851e8e08..9ee6bd892504e683a191484fb09259619759f36d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -68,7 +68,7 @@ class TileOp : public XlaOpKernel { bool all_multiples_are_one = true; bool one_dimension_is_broadcasted_without_multiple = true; for (int i = 0; i < input_dims; ++i) { - int multiple = xla::LiteralUtil::Get(literal, {i}); + int multiple = literal.Get({i}); OP_REQUIRES(ctx, multiple, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiple)); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index e9ac1ee91b8e86a7154f42b8c51dcbb5c8a32a83..a2ecbca124c28574560afea17e13889506869e36 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -352,9 +352,9 @@ class ResourceApplyRMSProp : public XlaOpKernel { b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); xla::ComputationDataHandle new_mom = b->Add(b->Mul(mom, momentum), - b->Div(b->Mul(grad, lr), + b->Mul(b->Mul(grad, lr), b->Pow(b->Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, 0.5)))); + XlaHelpers::FloatLiteral(b, type, -0.5)))); xla::ComputationDataHandle new_var = b->Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); @@ -364,112 +364,160 @@ class ResourceApplyRMSProp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp); -class ResourceApplyFtrl : public XlaOpKernel { - public: - explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - 
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); +void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, + bool has_l2_shrinkage) { + xla::ComputationBuilder* b = ctx->builder(); + + DataType var_type, accum_type, linear_type; + TensorShape var_shape, accum_shape, linear_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); + + OP_REQUIRES( + ctx, dtype == var_type && dtype == accum_type && dtype == linear_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyFtrlV2 must match: ", + DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ", + DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape), + errors::InvalidArgument( + "var and linear do not have the same shape", + var_shape.DebugString(), " ", linear_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape l1_shape = ctx->InputShape(5); + TensorShape l2_shape = ctx->InputShape(6); + TensorShape l2_shrinkage_shape; + TensorShape lr_power_shape; + if (has_l2_shrinkage) { + l2_shrinkage_shape = ctx->InputShape(7); + lr_power_shape = ctx->InputShape(8); + } else { + lr_power_shape = ctx->InputShape(7); } - void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument("var and grad do not have the same shape", + var_shape.DebugString(), " ", + grad_shape.DebugString())); - DataType var_type, accum_type, linear_type; - TensorShape 
var_shape, accum_shape, linear_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - OP_REQUIRES_OK( - ctx, ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString())); - OP_REQUIRES( - ctx, - dtype_ == var_type && dtype_ == accum_type && dtype_ == linear_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyFtrl must match: ", - DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", l1_shape.DebugString())); - OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), - errors::InvalidArgument( - "var and accum do not have the same shape", - var_shape.DebugString(), " ", accum_shape.DebugString())); - - OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape), - errors::InvalidArgument( - "var and linear do not have the same shape", - var_shape.DebugString(), " ", linear_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", l2_shape.DebugString())); - TensorShape grad_shape = ctx->InputShape(3); - TensorShape lr_shape = ctx->InputShape(4); - TensorShape l1_shape = ctx->InputShape(5); - TensorShape l2_shape = ctx->InputShape(6); - TensorShape lr_power_shape = ctx->InputShape(7); + if (has_l2_shrinkage) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shrinkage_shape), + errors::InvalidArgument("l2_shrinkage is not a scalar: ", + l2_shrinkage_shape.DebugString())); + } - OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), - errors::InvalidArgument( - "var and grad do not have the same shape", - var_shape.DebugString(), " ", 
grad_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape), + errors::InvalidArgument("lr_power is not a scalar: ", + lr_power_shape.DebugString())); + + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); + xla::ComputationDataHandle grad = ctx->Input(3); + xla::ComputationDataHandle lr = ctx->Input(4); + xla::ComputationDataHandle l1 = ctx->Input(5); + xla::ComputationDataHandle l2 = ctx->Input(6); + xla::ComputationDataHandle l2_shrinkage; + xla::ComputationDataHandle lr_power; + if (has_l2_shrinkage) { + l2_shrinkage = ctx->Input(7); + lr_power = ctx->Input(8); + } else { + lr_power = ctx->Input(7); + } - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), - errors::InvalidArgument("lr is not a scalar: ", - lr_shape.DebugString())); + // grad_to_use = grad + 2 * l2_shrinkage * var + // new_accum = accum + grad_to_use * grad_to_use + // linear += grad_to_use - + // (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var + // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 + // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 + // accum = new_accum + + xla::ComputationDataHandle zero_broadcast = b->Broadcast( + XlaHelpers::FloatLiteral(b, dtype, 0.0), var_shape.dim_sizes()); + xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::ComputationDataHandle grad_to_use; + if (has_l2_shrinkage) { + grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); + } else { + grad_to_use = grad; + } - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), - errors::InvalidArgument("l1 is not a scalar: ", - l1_shape.DebugString())); + xla::ComputationDataHandle new_accum = + b->Add(accum, b->Pow(grad_to_use, two)); + xla::ComputationDataHandle new_accum_lr_pow = + b->Pow(new_accum, b->Neg(lr_power)); + 
xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); + linear = b->Add( + linear, + b->Sub(grad_to_use, + b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); + xla::ComputationDataHandle quadratic = + b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); + xla::ComputationDataHandle pre_shrink = + b->Div(b->Sub(b->Mul(l1, b->Sign(linear)), linear), quadratic); + var = b->Select(b->Gt(b->Abs(linear), l1), pre_shrink, zero_broadcast); + accum = new_accum; + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype, linear)); +} - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), - errors::InvalidArgument("l2 is not a scalar: ", - l2_shape.DebugString())); +class ResourceApplyFtrl : public XlaOpKernel { + public: + explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape), - errors::InvalidArgument("lr_power is not a scalar: ", - lr_power_shape.DebugString())); + void Compile(XlaOpKernelContext* ctx) override { + CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false); + } - xla::ComputationDataHandle var, accum, linear; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle lr = ctx->Input(4); - xla::ComputationDataHandle l1 = ctx->Input(5); - xla::ComputationDataHandle l2 = ctx->Input(6); - xla::ComputationDataHandle lr_power = ctx->Input(7); - - // new_accum = accum + grad * grad - // linear += grad - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var - // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 - // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 - 
// accum = new_accum - - xla::ComputationDataHandle zero_broadcast = b->Broadcast( - XlaHelpers::FloatLiteral(b, dtype_, 0.0), var_shape.dim_sizes()); - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); - xla::ComputationDataHandle new_accum = b->Add(accum, b->Pow(grad, two)); - xla::ComputationDataHandle new_accum_lr_pow = - b->Pow(new_accum, b->Neg(lr_power)); - xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); - linear = b->Add( - linear, - b->Sub(grad, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), - var))); - xla::ComputationDataHandle quadratic = - b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); - xla::ComputationDataHandle pre_shrink = - b->Div(b->Sub(b->Mul(l1, b->Sign(linear)), linear), quadratic); - var = b->Select(b->Gt(b->Abs(linear), l1), pre_shrink, zero_broadcast); - accum = new_accum; +class ResourceApplyFtrlV2 : public XlaOpKernel { + public: + explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, linear)); + void Compile(XlaOpKernelContext* ctx) override { + CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true); } private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); +REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index abe4949f5dbc8034fa46828e3ff872cae7591d90..626ddd17d394d4a2e1c014c3a280949a415dce94 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -44,6 +44,8 
@@ namespace { // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); +XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. @@ -77,12 +79,19 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, b->LogicalAnd(b->Eq(fraction, half), is_odd)), b->Add(round_val, one), round_val); } -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); +// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. +static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, + DataType dtype, + const xla::ComputationDataHandle& x) { + auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); + return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); +} + +XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); XLAJIT_MAKE_UNARY(Rsqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); -XLAJIT_MAKE_UNARY(Sigmoid, - b->Map({x}, *ctx->GetOrCreateSigmoid(input_type(0)))); +XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); XLAJIT_MAKE_UNARY(Softplus, b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0))))); XLAJIT_MAKE_UNARY(Sqrt, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 1b04b8b802c5c6e9da337933a7c4cd99233ebe8d..0eea81b308bbee26cc607bcb95ffbf2d3f6abe0f 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" @@ -86,5 +87,69 @@ REGISTER_XLA_OP( Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes), AssignSubVariableOp); +class ResourceGatherOp : public XlaOpKernel { + public: + explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + // Get the shape of the resource tensor. + TensorShape resource_shape; + DataType resource_dtype; + OP_REQUIRES_OK( + ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape)); + + DataType expected_output_dtype = ctx->expected_output_dtype(0); + OP_REQUIRES(ctx, resource_dtype == expected_output_dtype, + errors::InvalidArgument( + "Variable dtype is ", DataTypeString(resource_dtype), + " but expected output dtype is ", + DataTypeString(expected_output_dtype), ".")); + + xla::ComputationDataHandle resource_handle; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle)); + + auto indices = ctx->Input(1); + auto indices_shape = ctx->InputShape(1); + const int num_indices = indices_shape.num_elements(); + + // Flatten the indices into 1-D. + auto indices_1d = builder->Reshape(indices, {num_indices}); + + // Compute the slice for each of these indices separately. 
+ std::vector slices(num_indices); + for (int i = 0; i < num_indices; ++i) { + auto index = builder->Slice(indices_1d, {i}, {i + 1}, {1}); + + auto start_indices = + XlaHelpers::PadWithZeros(builder, index, resource_shape.dims() - 1); + + auto slice_shape = resource_shape.dim_sizes(); + slice_shape[0] = 1LL; + + slices[i] = + builder->DynamicSlice(resource_handle, start_indices, slice_shape); + } + + // Concatenate the slices into one tensor. + xla::ComputationDataHandle concat = builder->ConcatInDim(slices, 0); + + // Compute the shape of the result tensor, which is: + // indices.shape + resource.shape[1:] + TensorShape gather_shape = indices_shape; + gather_shape.AppendShape(resource_shape); + gather_shape.RemoveDim(indices_shape.dims()); + + // Reshape the concatenated slices into the shape expected of the result + // tensor. + xla::ComputationDataHandle gather = + builder->Reshape(concat, gather_shape.dim_sizes()); + + ctx->SetOutput(0, gather); + } +}; +REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), + ResourceGatherOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c2031fc761e55ddb08a19dbc1b34a4d60e19562 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace { + +// Builds XlaCompiler argument descriptions `args` from `ctx`. +Status MakeXlaCompilerArgumentsFromInputs( + XlaOpKernelContext* ctx, std::vector* args, + bool* has_uninitialized_vars) { + VLOG(2) << "Num inputs " << ctx->num_inputs(); + args->resize(ctx->num_inputs()); + *has_uninitialized_vars = false; + for (int i = 0; i < ctx->num_inputs(); ++i) { + VLOG(2) << " Input " << i + << " type: " << DataTypeString(ctx->input_type(i)) + << " shape: " << ctx->InputShape(i).DebugString(); + XlaCompiler::Argument& arg = (*args)[i]; + DataType type = ctx->input_type(i); + // When reading a resource input, use the type and shape of the resource's + // current value. 
+ if (type == DT_RESOURCE) { + XlaResource* resource; + TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource)); + + arg.initialized = resource->value.handle() > 0; + switch (resource->kind) { + case XlaResource::kVariable: + arg.kind = XlaCompiler::Argument::kVariable; + break; + case XlaResource::kTensorArray: + arg.kind = XlaCompiler::Argument::kTensorArray; + break; + case XlaResource::kStack: + arg.kind = XlaCompiler::Argument::kStack; + break; + case XlaResource::kInvalid: + CHECK(false); + } + arg.type = resource->type; + if (arg.initialized) { + auto shape = ctx->builder()->GetShape(resource->value); + TF_RETURN_IF_ERROR(shape.status()); + arg.shape = *shape.ValueOrDie(); + } else { + *has_uninitialized_vars = true; + } + arg.tensor_array_size = resource->tensor_array_size; + arg.name = resource->name; + // TODO(phawkins): propagate TensorArray gradients into loops. + VLOG(2) << " resource " << resource->name + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = ctx->input_type(i); + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + } + } + return Status::OK(); +} + +} // anonymous namespace + +XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr)); + cond_name_attr_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); + body_name_attr_ = *name_attr; +} + +void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { + VLOG(1) << "WhileOp::Compile"; + + std::vector arguments; + bool has_uninitialized_vars; + OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs( + ctx, &arguments, &has_uninitialized_vars)); + + const bool use_tuple_arg = (arguments.size() != 1); + + xla::ComputationBuilder* builder = ctx->builder(); + XlaCompiler* compiler = ctx->compiler(); + 
+ VLOG(1) << "Compiling body"; + + // All resources that are inputs to the loop's body must also be + // present as loop body outputs; the signature of the loop's input and + // output must match. We ensure this by asking the compiler to include the + // current values of all resources, even if they haven't been updated by the + // computation. We must also ask the compiler to keep compile-time constant + // outputs as part of the generated computation, for the same reason. + // TODO(phawkins): consider adding loop-invariant inputs to XLA's While() + // operator. + XlaCompiler::CompileOptions body_options; + body_options.use_tuple_arg = use_tuple_arg; + body_options.return_updated_values_for_all_resources = true; + body_options.resolve_compile_time_constants = false; + XlaCompiler::CompilationResult body; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, + arguments, &body)); + + // We must use a static shape for parameters to an XLA compilation. However, + // we may not know the shape of a TensorArray if it is first written inside + // the loop. Ideally we would require the user to provide a static shape, + // but this is not always easy. + // So if uninitialized resources are used by the loop body, we compile the + // body function twice: + // 1) once with uninitialized resource inputs. We discard the computation + // but we assume resource shapes reach a fixpoint after one iteration. + // So we can use the output shapes of the resources as the "true" shapes. + // 2) again with the "correct" input shapes determined by (1). + if (has_uninitialized_vars) { + // Initializes any uninitialized resource with zero values of the
+ for (int i = 0; i < body.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaCompiler::Argument& arg = arguments[update.input_index]; + if (!arg.initialized) { + VLOG(2) << "Update shape for argument " << update.input_index << " " + << xla::ShapeUtil::HumanString(update.shape); + arg.initialized = true; + arg.shape = update.shape; + + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index, &resource)); + + std::unique_ptr zero = + xla::Literal::CreateFromShape(update.shape); + resource->value = builder->ConstantLiteral(*zero); + } + } + // Recompile the body with the "correct" shapes. + VLOG(1) << "Recompiling body with non-placeholder shapes"; + body = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, + arguments, &body)); + } + + VLOG(1) << "Compiling condition"; + + XlaCompiler::CompileOptions cond_options; + cond_options.use_tuple_arg = use_tuple_arg; + cond_options.resolve_compile_time_constants = false; + XlaCompiler::CompilationResult cond; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, + arguments, &cond)); + + xla::Shape body_input_shape, cond_input_shape; + if (use_tuple_arg) { + body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); + cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); + } else { + CHECK(!body.xla_input_shapes.empty()); + body_input_shape = body.xla_input_shapes[0]; + CHECK(!cond.xla_input_shapes.empty()); + cond_input_shape = cond.xla_input_shapes[0]; + } + + VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) + << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); + VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape) + << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape); + + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape), + 
errors::InvalidArgument( + "Input shapes of loop body and condition do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. ", + xla::ShapeUtil::HumanString(cond_input_shape))); + OP_REQUIRES( + ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape), + errors::InvalidArgument( + "Input and output shapes of loop body do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. ", + xla::ShapeUtil::HumanString(body.xla_output_shape))); + + xla::ComputationDataHandle data; + + int num_inputs = body.input_mapping.size(); + + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = body.input_mapping[i]; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + inputs[i] = resource->value; + } else { + inputs[i] = ctx->Input(input_num); // op input index, not slot i + } + } + + xla::ComputationDataHandle init; + if (use_tuple_arg) { + init = builder->Tuple(inputs); + } else { + init = inputs[0]; + } + + VLOG(1) << "Building while loop"; + + xla::ComputationDataHandle while_result = + builder->While(*cond.computation, *body.computation, init); + + auto get_loop_output = [&](int i) { + if (use_tuple_arg) { + return builder->GetTupleElement(while_result, i); + } else { + return while_result; + } + }; + + // Sets non-variable outputs. + for (int i = 0; i < ctx->num_outputs(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + ctx->SetOutput(body.input_mapping[i], get_loop_output(i)); + } + } + + // Updates the values of any resource variables modified by the loop.
+ for (int i = 0; i < body.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); + if (update.modified) { + int pos = body.outputs.size() + i; + resource->value = get_loop_output(pos); + } + VLOG(2) << "Loop-carried variable: pos: " << update.input_index + << " name: " << resource->name << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + // Copies the identity of the resource variable from input to output + // unchanged, even if the variable was not modified. + ctx->op_kernel_context()->set_output( + update.input_index, + ctx->op_kernel_context()->input(update.input_index)); + } + + VLOG(1) << "Done building while loop"; +} + +REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h new file mode 100644 index 0000000000000000000000000000000000000000..67edebabf9f643a919d0f06c228e2d224a49a2af --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional iteration primitive. +// +// The inputs and outputs of the loop body must agree on the number, types, and +// shapes of the Tensors carried around the loop body. +// +// Computations in while loops may read from and write to resource variables. +// Resource variables may be passed as arguments to a function's body and +// condition functions. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_resources` +// we ensure that all variables that appear in the input also appear at the +// end of the body's output. This ensures the loop body's input and output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +// +// For example, suppose we have a loop body with arguments: +// DT_INT32, DT_RESOURCE (pointing to a DT_BOOL var), DT_FLOAT +// and return values +// DT_INT32, DT_FLOAT +// It is an error for the body to return DT_RESOURCE values. +// +// The body will be lowered into an XLA computation that takes and returns a +// tuple with XLA type (I32, F32, PRED). Note the resource variable appears at +// the end of both the loop body's input and output argument lists.
+class XlaWhileOp : public XlaOpKernel { + public: + explicit XlaWhileOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + NameAttrList cond_name_attr_; + NameAttrList body_name_attr_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 1f2bc01cf4a48b37de585c55b781c239ee4b8f2a..576cd9bf9abb43e29d9eb8f706e0f42ac2d038e9 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -27,13 +27,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { TF_RETURN_IF_ERROR(TensorShapeToXLAShape( host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape())); - xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal); + literal->Reserve(host_tensor.NumElements()); // memcpy over the payload ... // TODO(phawkins): handle string types. 
size_t total_bytes = host_tensor.TotalBytes(); if (total_bytes > 0) { - void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal); + void* dst_ptr = literal->MutableInternalData(); const void* src_ptr = DMAHelper::base(&host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } @@ -51,11 +51,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, " to tensor of type ", DataTypeString(target_type)); } - TensorShape shape = XLAShapeToTensorShape(literal.shape()); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = xla::LiteralUtil::InternalData(literal); + const void* src_ptr = literal.InternalData(); void* dst_ptr = DMAHelper::base(host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 56993bc58534d1225f9177719804a69f561b3a06..f3d6787daaa1165b28ce63dfd501533fa0963edd 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -27,7 +27,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); + xla::Literal::CreateR1(gtl::ArraySlice(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +48,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); + xla::Literal::CreateR1(gtl::ArraySlice(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff 
--git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a2bd06861d5f383e3497a386b42a2e5a4035f1ea --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -0,0 +1,38 @@ +package( + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +cc_library( + name = "functional_ops", + srcs = ["functional_ops.cc"], + deps = [ + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_gen_op_wrapper_py( + name = "gen_functional_ops", + out = "gen_functional_ops.py", + deps = [ + ":functional_ops", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..38bcaa32278c4acf212881b10d66bb67b807a21c --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index f5ecb51a5b77e36e606ed1c48b8e2dbe76de0074..9d1992205b02665b99b1bd15b7b65a1fb8c35a51 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -24,12 +24,18 @@ limitations under the License. namespace tensorflow { // Convert an XLA Shape into the equivalent TensorFlow shape. 
-TensorShape XLAShapeToTensorShape(const xla::Shape& shape) { - TensorShape tensor_shape; +Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape) { + if (xla::ShapeUtil::IsTuple(shape)) { + return errors::InvalidArgument("XLA shape ", + xla::ShapeUtil::HumanString(shape), + " cannot be converted to a TensorShape"); + } + *tensor_shape = TensorShape(); for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { - tensor_shape.AddDim(shape.dimensions(i)); + tensor_shape->AddDim(shape.dimensions(i)); } - return tensor_shape; + return Status::OK(); } // Convert a TensorShape into the equivalent XLA Shape proto. diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 516dd636a970f78fda363a0b13961b8244dc2cd9..58240b9c965a194b9380ac7cd477ce7344e5ebe3 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -24,8 +24,10 @@ limitations under the License. namespace tensorflow { -// Convert an XLA Shape into the equivalent TensorFlow shape. -TensorShape XLAShapeToTensorShape(const xla::Shape& shape); +// Convert an XLA Shape into the equivalent TensorFlow shape. May fail since +// not all XLA shapes can be represented as TensorShapes. +Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape); // Convert a TensorShape into the equivalent XLA Shape proto. Unlike Tensorflow, // XLA shapes include the type. Not all `dtype` values can be represented by diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec --- /dev/null +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/test_util.h" + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { + +Status InstantiateFunctionForTest(const string& name, + const FunctionLibraryDefinition& library, + InstantiationResultForTest* result) { + const FunctionDef* fdef = library.Find(name); + TF_RET_CHECK(fdef != nullptr); + + auto get_func_sig = [&library](const string& op, const OpDef** sig) { + return library.LookUpOpDef(op, sig); + }; + InstantiationResult inst; + TF_RETURN_IF_ERROR( + InstantiateFunction(*fdef, AttrSlice(), get_func_sig, &inst)); + result->arg_types = inst.arg_types; + result->ret_types = inst.ret_types; + for (NodeDef& n : inst.nodes) { + *result->gdef.add_node() = std::move(n); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e6e4ae92ed23f3fca0f59b131dc73152e0947b72 --- /dev/null +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Same as InstantiationResult, but has a GraphDef instead of just nodes. +struct InstantiationResultForTest { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; + +// Instantiates a function, producing a GraphDef to compare against the +// expected graph. +Status InstantiateFunctionForTest(const string& name, + const FunctionLibraryDefinition& library, + InstantiationResultForTest* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 75630bee3961243b2389274f0f98200ee3a0a7eb..ec28bdccda47a326a0f60f2f73e8837b68e668cb 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -64,26 +64,36 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -struct XlaVariable { - // If this variable is visible externally, what was its argument number? +// Represents a resource, such as a Variable or TensorArray. 
+struct XlaResource { + enum Kind { + kInvalid, + kVariable, + kTensorArray, + kStack, + }; + + Kind kind = kInvalid; + + // If this resource is visible externally, what was its argument number? int arg_num = -1; - // A descriptive name for the variable, used in error messages. + // A descriptive name for the resource, used in error messages. string name; - // Current type and value of the variable. Uninitialized variables are + // Current type and value of the resource. Uninitialized resources are // represented by a default (zero) handle and type DT_INVALID. - // While the type of a variable is notionally fixed during execution, when - // a variable is first initialized we do not yet know its type, so we keep + // While the type of a resource is notionally fixed during execution, when + // a resource is first initialized we do not yet know its type, so we keep // track of its type dynamically. DataType type = DT_INVALID; xla::ComputationDataHandle value; - // Value of the variable at computation entry. Used to detect which + // Value of the resource at computation entry. Used to detect which // variables have new values that need to be written back. xla::ComputationDataHandle initial_value; - // We treat TensorArrays as a Variable with some extra metadata. + // TensorArray-specific fields // 'tensor_array_size' stores the expected size of the TensorArray. We need // to store this since sometimes TensorArrays must be initialized lazily since @@ -91,10 +101,10 @@ struct XlaVariable { int64 tensor_array_size = -1; // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes - // to an XlaVariable containing the gradient TensorArrays. We store a pointer + // to an XlaResource containing the gradient TensorArrays. We store a pointer // here since there should only be one gradient TensorArray per 'source' // string, irrespective of the number of calls to TensorArrayGrad. 
- std::unordered_map tensor_array_gradient; + std::unordered_map tensor_array_gradient; }; // A XlaExpression wraps an XLA computation. Each Tensor on an @@ -115,8 +125,8 @@ class XlaExpression { bool has_constant_value() const { return has_constant_value_; } const Tensor& constant_value() const { return constant_value_; } - void set_variable(XlaVariable* variable) { variable_ = variable; } - XlaVariable* variable() const { return variable_; } + void set_resource(XlaResource* resource) { resource_ = resource; } + XlaResource* resource() const { return resource_; } private: // The XLA handle of the expression's computation. @@ -128,7 +138,7 @@ class XlaExpression { bool has_constant_value_ = false; Tensor constant_value_; - XlaVariable* variable_ = nullptr; // Not owned. + XlaResource* resource_ = nullptr; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 580ce3d802e71ef99903321fff2bc7374d0a9470..11a62b23aeac5028ef3384b0ec6a07018b7a3cbf 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -59,9 +60,11 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, type, shape, name, tensor_array_size) != - std::tie(other.kind, other.type, other.shape, other.name, - other.tensor_array_size)) { + if (std::tie(kind, type, name, tensor_array_size) != + std::tie(other.kind, other.type, other.name, other.tensor_array_size)) { + return false; + } + if (!xla::ShapeUtil::Equal(shape, other.shape)) { return false; } if (constant_value.shape() != other.constant_value.shape()) { @@ -85,6 +88,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) (*options_.populate_resource_manager)(device_->resource_manager()); } + local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), + FunctionDefLibrary{})); + local_flib_runtime_.reset(NewFunctionLibraryRuntime( + &device_mgr_, Env::Default(), device_, options.graph_def_version, + local_flib_def_.get(), OptimizerOptions(), + nullptr /* custom_kernel_creator */)); flib_runtime_.reset(NewFunctionLibraryRuntime( &device_mgr_, Env::Default(), device_, options.graph_def_version, options.flib_def, OptimizerOptions(), @@ -103,6 +112,18 @@ uint64 XlaCompiler::SignatureHash::operator()( return std::hash()(signature.first); } +static Status GetFunctionBody(const NameAttrList& function, + FunctionLibraryRuntime* flib_runtime, + const FunctionBody** fbody) { + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flib_runtime->Instantiate( + function.name(), AttrSlice(&function.attr()), &handle)); + + *fbody = flib_runtime->GetFunctionBody(handle); + TF_RET_CHECK(*fbody); + return Status::OK(); +} + Status 
XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, @@ -117,21 +138,21 @@ Status XlaCompiler::CompileFunction( return Status::OK(); } - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flib_runtime_->Instantiate( - function.name(), AttrSlice(&function.attr()), &handle)); - - const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle); - CHECK(fbody); + const FunctionBody* fbody; + if (!GetFunctionBody(function, local_flib_runtime_.get(), &fbody).ok()) { + TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_.get(), &fbody)); + } TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_input_", function_id), *graph); + if (VLOG_IS_ON(2)) { + VLOG(2) << "XlaCompiler::CompileFunction: " + << dump_graph::DumpGraphToFile( + strings::StrCat("xla_compile_function_", function_id), + *graph); } // Optimize the graph before running the compiler. @@ -143,12 +164,6 @@ Status XlaCompiler::CompileFunction( optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(), /*device=*/nullptr, &graph); - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_optimized_", function_id), - *graph); - } - VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); @@ -249,35 +264,37 @@ Status BuildArguments(const std::vector& args, std::vector* input_shapes) { context_args->resize(args.size()); - // Argument numbers of arguments and variables that are to be passed to the + // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. 
- std::vector parameters, variables; + std::vector parameters, resources; parameters.reserve(args.size()); - variables.reserve(args.size()); + resources.reserve(args.size()); for (std::vector::size_type i = 0; i < args.size(); ++i) { XlaContext::Argument& context_arg = (*context_args)[i]; + context_arg.kind = args[i].kind; context_arg.name = args[i].name; context_arg.value.constant_value = args[i].constant_value; context_arg.value.type = args[i].type; switch (args[i].kind) { case XlaCompiler::Argument::kVariable: - variables.push_back(i); - context_arg.is_variable = true; - context_arg.value.is_constant = false; + case XlaCompiler::Argument::kTensorArray: + case XlaCompiler::Argument::kStack: + context_arg.is_resource = true; + if (args[i].initialized) { + resources.push_back(i); + context_arg.value.is_constant = false; + } else { + context_arg.value.is_constant = true; + } context_arg.tensor_array_size = args[i].tensor_array_size; break; case XlaCompiler::Argument::kParameter: parameters.push_back(i); context_arg.value.is_constant = false; break; - case XlaCompiler::Argument::kUninitializedVariable: - context_arg.is_variable = true; - context_arg.value.is_constant = true; - context_arg.tensor_array_size = args[i].tensor_array_size; - break; case XlaCompiler::Argument::kConstant: context_arg.value.is_constant = true; break; @@ -288,7 +305,7 @@ Status BuildArguments(const std::vector& args, // Append parameters containing variable values after the other runtime // parameters. - parameters.insert(parameters.end(), variables.begin(), variables.end()); + parameters.insert(parameters.end(), resources.begin(), resources.end()); if (parameters.empty()) { return Status::OK(); } @@ -298,10 +315,7 @@ Status BuildArguments(const std::vector& args, for (std::vector::size_type i = 0; i < input_shapes->size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; // Computes the shapes of non-constant arguments. 
- xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type)); - xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(), - &(*input_shapes)[i]); + (*input_shapes)[i] = arg.shape; (*input_mapping)[i] = parameters[i]; } @@ -329,22 +343,22 @@ Status BuildArguments(const std::vector& args, // variable states, generated by the symbolic evaluation. // If `has_side_effects` is true, the computation has side effects and should be // built even if it has no outputs. -// If `return_updated_values_for_all_variables` is true, all variables will be -// included in `variable_updates`, regardless of whether their value changed. +// If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*variable_updates` to a description of variables whose values are +// Sets `*resource_updates` to a description of resources whose values are // written by the computation; the variable writes are the last -// `variable_updates.size()` return values from the computation. Each entry in -// `variable_updates` is a (input_index, type) pair, where `input_index` is the +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a (input_index, type) pair, where `input_index` is the // index of a resource variable argument to the computation, and `type` is the // type of the final output. 
Status BuildComputation( const std::vector& retvals, - const std::vector>& variables, - bool has_side_effects, bool return_updated_values_for_all_variables, + const std::vector>& resources, + bool has_side_effects, bool return_updated_values_for_all_resources, xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_nonconst_outputs, - std::vector* variable_updates) { + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* resource_updates) { std::vector elems; elems.reserve(retvals.size()); for (const XlaContext::HandleOrConstant& retval : retvals) { @@ -354,24 +368,24 @@ Status BuildComputation( } *num_nonconst_outputs = elems.size(); - // Add return values for variables whose values have changed. - std::vector arg_vars; - arg_vars.reserve(variables.size()); - for (const auto& var : variables) { + // Add return values for resources whose values have changed. + std::vector arg_vars; + arg_vars.reserve(resources.size()); + for (const auto& var : resources) { if (var->arg_num >= 0) { arg_vars.push_back(var.get()); } } std::sort(arg_vars.begin(), arg_vars.end(), - [](const XlaVariable* a, const XlaVariable* b) { + [](const XlaResource* a, const XlaResource* b) { return a->arg_num < b->arg_num; }); - for (const XlaVariable* var : arg_vars) { + for (const XlaResource* var : arg_vars) { bool modified = var->value.handle() != var->initial_value.handle(); - if (return_updated_values_for_all_variables || modified) { - variable_updates->emplace_back(); - XlaCompiler::VariableUpdate& update = variable_updates->back(); + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); update.input_index = var->arg_num; update.type = var->type; update.modified = modified; @@ -379,6 +393,7 @@ Status BuildComputation( } } + *num_computation_outputs = elems.size(); if (!elems.empty() || has_side_effects) { // Builds a empty tuple return value for 
computations that have side effects // but have no return values. @@ -401,6 +416,18 @@ Status BuildComputation( return Status::OK(); } +void AssignMajorToMinorLayout(xla::Shape* shape) { + if (xla::ShapeUtil::IsTuple(*shape)) { + for (xla::Shape& elem_shape : *shape->mutable_tuple_shapes()) { + AssignMajorToMinorLayout(&elem_shape); + } + } else { + auto& minor_to_major = *shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major.Resize(xla::ShapeUtil::Rank(*shape), 0); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } +} + } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, @@ -410,13 +437,24 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + if (VLOG_IS_ON(2)) { + VLOG(2) << "XlaCompiler::CompileGraph: " + << dump_graph::DumpGraphToFile( + strings::StrCat("xla_compile_graph_", name), *graph); + } + // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); + // Converts Tensorflow's graph control-flow constructs into functional + // control-flow that can be compiled into XLA code. 
+ TF_RETURN_IF_ERROR( + FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); + xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options_.resolve_compile_time_constants); + options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); result->tuple_arg = options.use_tuple_arg; @@ -431,12 +469,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, flib_runtime_.get(), NextStepId())); int num_nonconst_outputs; + int num_computation_outputs; result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( - context->retvals(), context->variables(), context->has_side_effects(), - options.return_updated_values_for_all_variables, &builder, - result->computation.get(), &num_nonconst_outputs, - &result->variable_updates)); + context->retvals(), context->resources(), context->has_side_effects(), + options.return_updated_values_for_all_resources, &builder, + result->computation.get(), &num_computation_outputs, + &num_nonconst_outputs, &result->resource_updates)); result->requires_runtime_context = context->has_context_parameter(); @@ -473,23 +512,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); - auto num_computation_outputs = - (xla::ShapeUtil::IsTuple(result->xla_output_shape)) - ? xla::ShapeUtil::TupleElementCount(result->xla_output_shape) - : 1; // Tensorflow expects a major-to-minor order of results. 
- if (1 == num_computation_outputs) { - xla::Shape& s = result->xla_output_shape; - auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major(); - minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); - } else { - for (xla::Shape& s : *result->xla_output_shape.mutable_tuple_shapes()) { - auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major(); - minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); - } - } + AssignMajorToMinorLayout(&result->xla_output_shape); // Converts the output shapes to TensorShapes. int computation_output = 0; @@ -501,26 +525,26 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, OutputDescription& output = result->outputs[i]; output.is_constant = false; if (num_computation_outputs > 1) { - output.shape = - XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, + computation_output), + &output.shape)); } else { - output.shape = XLAShapeToTensorShape(result->xla_output_shape); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(result->xla_output_shape, &output.shape)); } ++computation_output; } } - for (std::vector::size_type i = 0; - i < result->variable_updates.size(); ++i) { + for (std::vector::size_type i = 0; + i < result->resource_updates.size(); ++i) { if (num_computation_outputs > 1) { - result->variable_updates[i].shape = - XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output)); + result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape( + result->xla_output_shape, computation_output); } else { CHECK_EQ(0, computation_output); - result->variable_updates[i].shape = - XLAShapeToTensorShape(result->xla_output_shape); + 
result->resource_updates[i].shape = result->xla_output_shape; } ++computation_output; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 131430553252e2b62315c6388a53058bdf20eb7f..7251c92edb2b56e2d738abc4570e74e4c9dc6c62 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -85,27 +85,31 @@ class XlaCompiler { // Argument is a compile-time constant. No associated runtime parameter. kConstant, - // Argument is a variable that has not been initialized yet. No associated - // runtime parameter. - kUninitializedVariable, - - // Argument is a variable that already has a value set. Expects a runtime - // parameter containing the current value. + // Argument is a Variable resource. Has an associated runtime parameter + // iff `initialized` is true. kVariable, + // Argument is a TensorArray resource. Has an associated runtime parameter + // iff `initialized` is true. + kTensorArray, + + // Argument is a Stack resource. Has an associated runtime parameter + // iff `initialized` is true. + kStack, + // Argument is a run-time parameter. kParameter, }; Kind kind = kInvalid; - // The type of the argument. If the argument is a resource variable, this + // The type of the argument. If the argument is a resource, this // is the type of the variable's value, not DT_RESOURCE. DataType type; - // The shape of the argument. If the argument is a resource variable, this - // is the shape of the variable's value. - TensorShape shape; + // The shape of the argument. If the argument is a resource, this is the + // shape of the resource's value. + xla::Shape shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -114,8 +118,11 @@ class XlaCompiler { // The name of this argument, used for debugging. string name; - // For a kVariable or kUninitializedVariable corresponding to a TensorArray, - // what is the tensor array's declared size? 
+ // For a kVariable or kTensorArray, has this resource been initialized? + bool initialized = false; + + // For a kTensorArray, what is the array's declared size? (Used for lazy + // initialization.) int64 tensor_array_size = -1; bool operator==(const Argument& other) const; @@ -133,23 +140,23 @@ class XlaCompiler { }; // Describes a variable write side effect of the computation. - struct VariableUpdate { + struct ResourceUpdate { // Index of the input that contains the variable resource to write to. int input_index; // Type and shape of the tensor to be written back. DataType type; - TensorShape shape; + xla::Shape shape; // Was the value of the variable modified by the computation? - // (Always true, unless `return_updated_values_for_all_variables` is true.) + // (Always true, unless `return_updated_values_for_all_resources` is true.) bool modified; }; struct CompilationResult { // Vector that maps from the parameters of the XLA computation to their // original argument positions. To handle compile-time constant inputs and - // variables, the parameters to the XLA computation may be a subset of the + // resources, the parameters to the XLA computation may be a subset of the // original arguments, and are not necessarily in the same order.) std::vector input_mapping; @@ -172,10 +179,10 @@ class XlaCompiler { // containing both constant and non-constant results. std::vector outputs; - // Variables whose values were updated by the computation, ordered - // by return value position. Variable updates follow the non-constant + // Resources whose values were updated by the computation, ordered + // by return value position. Resource updates follow the non-constant // results in the outputs of XLA computation. - std::vector variable_updates; + std::vector resource_updates; // The XLA computation built from the tensorflow subgraph. May be null // if the output consists solely of compile-time constants. @@ -206,12 +213,6 @@ class XlaCompiler { // stored in device memory. 
bool local_executable_has_hybrid_result = false; - // If 'resolve_compile_time_constants' is true, then outputs of a - // computation that are known to be compile-time constants will be returned - // as Tensors at compile-time, rather than as run-time outputs of the - // computation. - bool resolve_compile_time_constants = true; - // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects @@ -229,12 +230,18 @@ class XlaCompiler { // arguments; if false, each argument gets its own parameter. bool use_tuple_arg = false; - // If 'return_updated_values_for_all_variables' is true, then updated - // values of all resource variables arguments will be included in the - // 'variable_updates' of the computation, even if the variable was not + // If 'return_updated_values_for_all_resources' is true, then updated + // values of all resource resources arguments will be included in the + // 'resource_updates' of the computation, even if the resource was not // modified by the computation. Used when compiling loop bodies to ensure // the input and output signatures match. - bool return_updated_values_for_all_variables = false; + bool return_updated_values_for_all_resources = false; + + // If 'resolve_compile_time_constants' is true, then outputs of a + // computation that are known to be compile-time constants will be returned + // as Tensors at compile-time, rather than as run-time outputs of the + // computation. + bool resolve_compile_time_constants = true; }; // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. @@ -294,6 +301,12 @@ class XlaCompiler { XlaCompilationDevice* device_; // Owned by device_mgr_ DeviceMgr device_mgr_; + // To avoid copying the client's function library, use a local function + // library and runtime for functions created as part of the functionalize + // control flow transformation. 
+ std::unique_ptr local_flib_def_; + std::unique_ptr local_flib_runtime_; + std::unique_ptr flib_runtime_; struct SignatureHash { diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 58d74057d101cdef89fca24ec6c0858291d825fa..42bbccd1d365f10d8d8d1bd839b5b4de57fb1656 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -149,10 +150,10 @@ TEST_F(XlaCompilerTest, Simple) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = TensorShape({2}); + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = TensorShape({2}); + args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -163,9 +164,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. 
std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -179,7 +180,7 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected_literal = - xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal::CreateR1({4, 143}); xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } @@ -201,21 +202,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = TensorShape({2}); + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + XlaCompiler::Options options = DefaultOptions(); + XlaCompiler compiler(options); { // Compiles the graph, with resolve_compile_time_constants enabled. - XlaCompiler::Options options = DefaultOptions(); - options.resolve_compile_time_constants = true; - XlaCompiler compiler(options); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), - "constants", std::move(graph_copy), args, - &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph_copy), args, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_TRUE(result.outputs[0].is_constant); @@ -225,7 +226,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. 
std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -236,23 +237,20 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected_literal = - xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal::CreateR1({-7, -42}); xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } { // Compiles the graph, with resolve_compile_time_constants disabled. - XlaCompiler::Options options = DefaultOptions(); - options.resolve_compile_time_constants = false; - XlaCompiler compiler(options); - std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = false; XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), - "constants", std::move(graph_copy), args, - &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph_copy), args, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_FALSE(result.outputs[0].is_constant); @@ -260,7 +258,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. 
std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -270,12 +268,11 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); + std::unique_ptr expected0 = xla::Literal::CreateR0(7); std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); } } @@ -294,7 +291,7 @@ TEST_F(XlaCompilerTest, ResourceManager) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = TensorShape({2}); + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); DummyResourceForTest* resource = new DummyResourceForTest(); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 4440b530696db6125e0af0606be49e2d834dbd9f..d4d493b456f668ecfbdd0164c573b9ae2aa810e9 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() { xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateVariable(int arg_num, string name, DataType type, +Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, + string name, DataType type, const xla::ComputationDataHandle& handle, - XlaVariable** variable) { - variables_.emplace_back(new XlaVariable); - *variable = variables_.back().get(); - XlaVariable& var = **variable; - var.arg_num = arg_num; - var.name = std::move(name); - var.type = type; - var.initial_value = var.value 
= handle; + XlaResource** resource) { + resources_.emplace_back(new XlaResource); + *resource = resources_.back().get(); + XlaResource& r = **resource; + r.kind = kind; + r.arg_num = arg_num; + r.name = std::move(name); + r.type = type; + r.initial_value = r.value = handle; return Status::OK(); } @@ -170,27 +172,6 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) { - return LookupOrCreate(type, &sigmoid_func_, [this, type] { - const string type_string = DataTypeString(type); - VLOG(1) << "Building Sigmoid() for " << type_string; - xla::ComputationBuilder b(builder()->client(), - "sigmoid<" + type_string + ">"); - xla::PrimitiveType xla_type; - TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - // Clamp the inputs to the range [-18, 18] since anything outside - // this range is 0.0f or 1.0f in single-precision. We must clamp the range - // of x to avoid incorrect outputs due to fast-math optimizations for large - // negative x. - x = b.Clamp(XlaHelpers::IntegerLiteral(&b, type, -18), x, - XlaHelpers::IntegerLiteral(&b, type, 18)); - auto one = XlaHelpers::One(&b, type); - b.Div(one, b.Add(b.Exp(b.Neg(x)), one)); - return b.Build().ConsumeValueOrDie(); - }); -} - const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 3978baaf637b4948510eafe37de94a383a87ddc3..544921b9e38fb52e70b9f67ba10f7c79dc53c657 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -52,11 +52,13 @@ class XlaContext : public ResourceBase { }; struct Argument { - // Descriptive name for the variable, for use in error messages. 
+ XlaCompiler::Argument::Kind kind; + + // Descriptive name for the resource, for use in error messages. string name; - // Is this a variable? - bool is_variable = false; + // Is this a resource? + bool is_resource = false; HandleOrConstant value; @@ -106,15 +108,15 @@ class XlaContext : public ResourceBase { bool has_side_effects() const { return has_side_effects_; } - // Creates a variable with variable `variable_id` and initial type `type` and + // Creates a resource with resource `kind` and initial type `type` and // value `handle`. `name` is a descriptive name for use in error messages. - // Fails if the variable already exists. - Status CreateVariable(int arg_num, string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaVariable** variable); + // Fails if the resource already exists. + Status CreateResource(XlaResource::Kind kind, int arg_num, string name, + DataType type, const xla::ComputationDataHandle& handle, + XlaResource** resource); - const std::vector>& variables() { - return variables_; + const std::vector>& resources() { + return resources_; } // Get an XLA lambda to compute Max. This is cached in the @@ -127,11 +129,6 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); - // Get an XLA lambda to compute Sigmoid. This is cached in the - // XlaContext since it may be used by multiple Ops. There is a - // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateSigmoid(const DataType type); - // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -166,8 +163,8 @@ class XlaContext : public ResourceBase { // Does the computation have side effects, i.e., Send() calls? bool has_side_effects_ = false; - // Holds ownership of variables. The variables are not ordered. 
- std::vector> variables_; + // Holds ownership of resources. The resources are not ordered. + std::vector> resources_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f060f8f2f178b2bc56caf7a3df9df32c8a407473..3af866f9be516beae7e6fa64b5a4cf1fef843f67 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -30,28 +30,28 @@ xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + return b->ConstantLiteral(xla::Literal::MinValue(type)); } xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + return b->ConstantLiteral(xla::Literal::MaxValue(type)); } xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::Zero(type)); + return b->ConstantLiteral(xla::Literal::Zero(type)); } xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::One(type)); + return b->ConstantLiteral(xla::Literal::One(type)); } xla::ComputationDataHandle XlaHelpers::IntegerLiteral( @@ -61,28 +61,28 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { case xla::U8: - literal = *xla::LiteralUtil::CreateR0(value); + literal = 
*xla::Literal::CreateR0(value); break; case xla::U32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::U64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S8: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::F32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::F64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -91,7 +91,7 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: literal = - *xla::LiteralUtil::CreateR0(static_cast(value)); + *xla::Literal::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; @@ -205,4 +205,13 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, return Status::OK(); } +xla::ComputationDataHandle XlaHelpers::PadWithZeros( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + int count) { + xla::ComputationDataHandle zero = builder->ConstantR1({0}); + std::vector xs(count + 1, zero); + xs[0] = builder->Reshape(x, {1}); + return builder->ConcatInDim(xs, 0); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index a141ee05c13ed2e09fab69946ba400ab6cd628a9..2166ce363608ea65ba8cd9db856aff9ee2715005 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -77,6 +77,11 @@ 
class XlaHelpers { const xla::ComputationDataHandle& on_value, const xla::ComputationDataHandle& off_value, xla::ComputationDataHandle* one_hot); + + // Pads 'x' with 'count' zeros. 'x' must have 1 element. + static xla::ComputationDataHandle PadWithZeros( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + int count); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 3272b1efa153c0ecab720583277175b81fe59509..c5a68e05d9e1dfa3ed1c648e95d3690fadef8b51 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); CHECK(expression->handle().handle() != 0 || - expression->variable() != nullptr); + expression->resource() != nullptr); VLOG(1) << "Fetched T" << expression->handle().handle(); return expression; } @@ -144,9 +144,9 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { - *out = xla::LiteralUtil::Get(literal, {}); + *out = literal.Get({}); } else if (literal.shape().element_type() == xla::S64) { - *out = xla::LiteralUtil::Get(literal, {}); + *out = literal.Get({}); } else { return errors::InvalidArgument("value must be either int32 or int64"); } @@ -168,11 +168,11 @@ static Status LiteralToInt64Vector(const xla::Literal& literal, int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { for (int64 i = 0; i < size; ++i) { - out->push_back(xla::LiteralUtil::Get(literal, {i})); + out->push_back(literal.Get({i})); } } else if (literal.shape().element_type() == xla::S64) { for (int64 i = 0; i < size; ++i) { - out->push_back(xla::LiteralUtil::Get(literal, 
{i})); + out->push_back(literal.Get({i})); } } else { return errors::InvalidArgument("value must be either int32 or int64"); @@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput( int index, xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput( return Status::OK(); } -string XlaOpKernelContext::VariableDebugString(int index) { - const Tensor& tensor = context_->input(index); - const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); - if (!variable) { - return ""; - } - return variable->name; -} - Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -287,7 +279,8 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, if (!shape_or_status.ok()) { return shape_or_status.status(); } - *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); return Status::OK(); } @@ -304,10 +297,11 @@ void XlaOpKernelContext::SetOutput(int index, // The 
step's default allocator is the dummy XlaCompilationAllocator which // simply allocates a metadata buffer to hold the expression to which it // corresponds. - OP_REQUIRES_OK( - context_, - context_->allocate_output( - index, XLAShapeToTensorShape(*shape.ValueOrDie()), &output)); + TensorShape tensor_shape; + OP_REQUIRES_OK(context_, + XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape)); + OP_REQUIRES_OK(context_, + context_->allocate_output(index, tensor_shape, &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. @@ -337,33 +331,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } -void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) { +void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; - // The shape of the output tensor is the shape of the variable resource - // (i.e., a scalar), not the shape of the variable's value. + // The shape of the output tensor is the shape of the resource itself + // (i.e., a scalar), not the shape of the resource's value. 
OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape(), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_variable(variable); + expression->set_resource(resource); } -Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) { +Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { const XlaExpression* expression = CastExpressionFromTensor(context_->input(index)); - TF_RET_CHECK(expression->variable() != nullptr); - *variable = expression->variable(); + TF_RET_CHECK(expression->resource() != nullptr); + *resource = expression->resource(); return Status::OK(); } Status XlaOpKernelContext::AssignVariable( - int index, DataType type, const xla::ComputationDataHandle& handle) { + int input_index, DataType type, const xla::ComputationDataHandle& handle) { TF_RET_CHECK(handle.handle() != 0); SetOpHasSideEffects(); const XlaExpression* expression = - CastExpressionFromTensor(context_->input(index)); - XlaVariable* variable = expression->variable(); + CastExpressionFromTensor(context_->input(input_index)); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (!((variable->type == DT_INVALID && type != DT_INVALID) || (variable->type == type))) { return errors::InvalidArgument( @@ -398,11 +393,6 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateSigmoid( - const DataType type) { - return XlaContext::Get(context_).GetOrCreateSigmoid(type); -} - XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 
a25774c3a6a4a7212d157766a23e73063c2deab8..30b794c8c198cae6bf3b11794b35049b729063e1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -148,6 +148,12 @@ class XlaOpKernelContext { // Variables + // Sets '*resource' to the resource associated with input `index`. + Status GetResourceInput(int index, XlaResource** resource); + + // Sets output 'index' to be a reference to resource 'resource'. + void SetResourceOutput(int index, XlaResource* resource); + // Sets `*type` and `*shape` to the current type and shape of a variable's // value. Status GetVariableTypeAndShape(int index, DataType* type, @@ -158,20 +164,10 @@ class XlaOpKernelContext { Status ReadVariableInput(int index, xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `variable_index`. Marks the operator as having side effects. - Status AssignVariable(int variable_index, DataType type, + // `input_index`. Marks the operator as having side effects. + Status AssignVariable(int input_index, DataType type, const xla::ComputationDataHandle& handle); - // Sets '*variable' to the variable associated with input `index`. - Status GetVariableInput(int index, XlaVariable** variable); - - // Sets output 'index' to be a reference to variable 'variable'. Used - // to propagate resource variables through the compilation. - void SetVariableOutput(int index, XlaVariable* variable); - - // Returns a human-readable debug string describing 'variable_index'. - string VariableDebugString(int variable_index); - // Helper routines for the OP_REQUIRES macros void CtxFailure(Status s); void CtxFailureWithWarning(Status s); @@ -205,11 +201,6 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); - // Get an XLA lambda to compute Sigmoid. This is cached in the - // XlaContext since it may be used by multiple Ops. 
There is a - // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateSigmoid(const DataType type); - private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 1bb0d8528994b957ccebeabce8bc48227122e366..d059c7a23ef2955cdd1280d1ceff7fc39b625631 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -24,6 +24,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -34,11 +36,18 @@ const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; const char* const DEVICE_XLA_CPU = "XLA_CPU"; const char* const DEVICE_XLA_GPU = "XLA_GPU"; -// Is platform 'id' supported by XLA? 
-static bool IsPlatformSupported(perftools::gputools::Platform::Id id) { - auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id); - if (!platform.ok()) return false; - return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok(); +static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + NodeDef node_def; + node_def.set_name("_XlaLaunch-op"); + node_def.set_op("_XlaLaunch"); + string kernel_class_name; + TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, + &kernel_class_name)); + VLOG(1) << "LaunchOpHasKernelForDevice" + << " kernel_class_name: " << kernel_class_name; + return Status::OK(); } XlaOpRegistry::XlaOpRegistry() = default; @@ -75,7 +84,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; // GetCompilationDevice is called. static void* registration_init = [®istry]() { mutex_lock lock(registry.mutex_); - if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) { + if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; @@ -83,7 +92,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; registration.enable_jit_by_default = false; registration.compile_resource_ops = false; } - if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) { + if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 9a39cc96754fe8fef2c52e5de9626bcad30bf483..47d61a21a13b35050dff3d95c3856ee3f356f3c7 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ 
b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,9 +45,10 @@ extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; constexpr std::array kIntTypes = {{DT_INT32, DT_INT64}}; -constexpr std::array kFloatTypes = {{DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}}; +constexpr std::array kFloatTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE}}; +constexpr std::array kNumericTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}}; constexpr std::array kCpuAllTypes = { {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 2491cc3f7a2011827f4e093f287b525155153b71..e0a03a78f1d847ee03e136d46bdb28b0a085dc4c 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -46,21 +46,18 @@ xla_proto_library( ], ) -# This is a headers target that extra XLA devices can use to prevent -# circular dependencies. Devices that are compiled as separate shared -# objects can also use it to prevent linking of library code. 
-cc_header_only_library( - name = "xla_headers_lib", - visibility = ["//visibility:public"], +cc_library( + name = "execution_options_util", + srcs = [ + "execution_options_util.cc", + ], + hdrs = [ + "execution_options_util.h", + ], + visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_evaluator", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:stream_executor_headers_lib", + ":xla_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", ], ) @@ -135,7 +132,10 @@ cc_library( cc_library( name = "statusor", srcs = ["statusor.cc"], - hdrs = ["statusor.h"], + hdrs = [ + "statusor.h", + "statusor_internals.h", + ], visibility = ["//visibility:public"], deps = [ ":status", @@ -171,7 +171,6 @@ cc_library( ":status", ":types", ":xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:util_flags", "//tensorflow/core:lib", ], ) @@ -226,7 +225,6 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", @@ -256,7 +254,6 @@ cc_test( ":shape_util", ":test", ":test_helpers", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/core:test_main", ], ) @@ -577,6 +574,7 @@ cc_test( srcs = ["reference_util_test.cc"], deps = [ ":array2d", + ":array3d", ":array4d", ":literal_util", ":reference_util", @@ -602,3 +600,17 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. 
+cc_header_only_library( + name = "xla_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_data_proto", + ":xla_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:stream_executor_headers_lib", + ], +) diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index d93f968f4d7a8c30129f4e14c4db06c25187cb45..4c7fce1aaf1faf4bd08bca38bc8eb2b47303b575 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -207,6 +207,18 @@ class Array4D { } } + // Invokes a callback with the (indices, value) for each cell in the 4D array. + void Each( + std::function, T)> f) const { + // We const_cast to be able to use the common non-const implementation, + // but prevent modification of the data by passing it by-value to the + // caller. + const_cast(this)->Each( + [&f](tensorflow::gtl::ArraySlice indices, T* value) { + f(indices, *value); + }); + } + // Fills all of the {p,z} with the array provided, which specifies {y,x}. 
void FillWithYX(const Array2D& value) { CHECK_EQ(value.height(), height()); diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 63c6d9ddaca5e9e336e29cd3b23cfd921d4ce9e7..a998b91c89d79ac5e354d2a3edf5fb78695d73cb 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -62,6 +62,7 @@ cc_library( deps = [ ":computation", ":global_data", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:service_interface", "//tensorflow/compiler/xla:status_macros", @@ -70,6 +71,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -114,7 +116,6 @@ cc_library( "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 454d0fbd9650c4d77a62b4c25a5407e36bd191f8..1799bbd3480daacc204b42f168a7f8e9149db58b 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -197,7 +199,10 @@ StatusOr> Client::Execute( ExecutionProfile* execution_profile) { ExecuteRequest request; *request.mutable_computation() = computation.handle(); - if (execution_options != nullptr) { + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { @@ -298,7 +303,9 @@ StatusOr Client::ExecuteAsync( for (GlobalData* argument : arguments) { *request.add_arguments() = argument->handle(); } - if (execution_options != nullptr) { + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { *request.mutable_execution_options() = *execution_options; } @@ -376,9 +383,10 @@ StatusOr>> Client::DeconstructTuple( } StatusOr Client::GetComputationStats( - const Computation& computation) const { + const Computation& computation, const DebugOptions& debug_options) const { ComputationStatsRequest request; *request.mutable_computation() = computation.handle(); + *request.mutable_debug_options() = debug_options; ComputationStatsResponse response; VLOG(1) << "making computation stats request"; @@ -427,7 +435,10 @@ StatusOr Client::GetShape(const GlobalData& data) { StatusOr Client::ExecutionStatsAsString( const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN(auto computation_stats, GetComputationStats(computation)); + TF_ASSIGN_OR_RETURN( + auto computation_stats, + GetComputationStats(computation, + legacy_flags::GetDebugOptionsFromFlags())); int64 total_flops = computation_stats.flop_count() + 
computation_stats.transcendental_count(); if (profile.compute_time_ns() > 0) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 797835160fa2850f108e85ff3147abffd9f86ad8..69d3642911fa8fe87ceb347d929e95ffd972615b 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -150,7 +150,7 @@ class Client { // Retrieves the statistics of the given computation. StatusOr GetComputationStats( - const Computation& computation) const; + const Computation& computation, const DebugOptions& debug_options) const; // Returns the Shape of the given array specified by 'data'. The shape // includes the Layout of the array as it is stored on the service. diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 8238261e1c90cadeda9005e437d684d3770bd67b..b1663bc815719c3da75b37593ac665b1f3493db8 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -23,6 +23,13 @@ limitations under the License. 
namespace xla { +LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform, + int number_of_replicas, + int intra_op_parallelism_threads) + : platform_(platform), + number_of_replicas_(number_of_replicas), + intra_op_parallelism_threads_(intra_op_parallelism_threads) {} + LocalClientOptions& LocalClientOptions::set_platform( perftools::gputools::Platform* platform) { platform_ = platform; @@ -142,4 +149,12 @@ ClientLibrary::GetOrCreateCompileOnlyClient( return cl; } +/* static */ void ClientLibrary::DestroyLocalInstances() { + ClientLibrary& client_library = Singleton(); + tensorflow::mutex_lock lock(client_library.service_mutex_); + + client_library.local_instances_.clear(); + client_library.compile_only_instances_.clear(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 3ddd235d0efeeb78f49eafbf670d7c74a88960dd..a6f30d82e43587135697e76e8bc7d122edc0f602 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -43,13 +43,16 @@ namespace xla { // Options to configure the local client when it is created. class LocalClientOptions { public: + LocalClientOptions(perftools::gputools::Platform* platform = nullptr, + int number_of_replicas = 1, + int intra_op_parallelism_threads = -1); + // Set the platform backing the service, or nullptr for the default platform. LocalClientOptions& set_platform(perftools::gputools::Platform* platform); perftools::gputools::Platform* platform() const; // Set the number of replicas to use when compiling replicated - // programs. The default is -1 meaning that the value is read from - // the xla_replicas flag. + // programs. 
LocalClientOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; @@ -58,9 +61,9 @@ class LocalClientOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_ = nullptr; - int number_of_replicas_ = -1; - int intra_op_parallelism_threads_ = -1; + perftools::gputools::Platform* platform_; + int number_of_replicas_; + int intra_op_parallelism_threads_; }; class ClientLibrary { @@ -90,6 +93,11 @@ class ClientLibrary { static StatusOr GetOrCreateCompileOnlyClient( perftools::gputools::Platform* platform = nullptr); + // Clears the local instance and compile only instance caches. The client + // pointers returned by the previous GetOrCreateLocalClient() or + // GetOrCreateCompileOnlyClient() invocations are not valid anymore. + static void DestroyLocalInstances(); + private: // Returns the singleton instance of ClientLibrary. static ClientLibrary& Singleton(); diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 37bf697683b0f5f61a1b915628920b0752116a32..212bcd27d29d6e3c06362344bd370d5ef24d6f56 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -971,6 +971,16 @@ ComputationDataHandle ComputationBuilder::Sign( return UnaryOp(UNOP_SIGN, operand); } +ComputationDataHandle ComputationBuilder::Cos( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_COS, operand); +} + +ComputationDataHandle ComputationBuilder::Sin( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_SIN, operand); +} + ComputationDataHandle ComputationBuilder::Tanh( const ComputationDataHandle& operand) { return UnaryOp(UNOP_TANH, operand); @@ -1411,6 +1421,72 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::BatchNormTraining( + const 
ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, float epsilon, int64 feature_index) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + BatchNormTrainingRequest request; + *request.mutable_operand() = operand; + *request.mutable_scale() = scale; + *request.mutable_offset() = offset; + request.set_epsilon(epsilon); + request.set_feature_index(feature_index); + + OpRequest op_request; + *op_request.mutable_batch_norm_training_request() = request; + *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); + + OpResponse response; + + VLOG(2) << "making BatchNormTraining request"; + + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::BatchNormInference( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, const ComputationDataHandle& mean, + const ComputationDataHandle& variance, float epsilon, int64 feature_index) { + // TODO(b/62843645): Implement BatchNormInference. 
+ NoteError(Unimplemented("BatchNormInference is not implemented yet.")); + return ComputationDataHandle(); +} + +ComputationDataHandle ComputationBuilder::BatchNormGrad( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& mean, const ComputationDataHandle& var, + const ComputationDataHandle& grad_output, float epsilon, + int64 feature_index) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + BatchNormGradRequest request; + *request.mutable_operand() = operand; + *request.mutable_scale() = scale; + *request.mutable_mean() = mean; + *request.mutable_variance() = var; + *request.mutable_grad_output() = grad_output; + request.set_epsilon(epsilon); + request.set_feature_index(feature_index); + + OpRequest op_request; + *op_request.mutable_batch_norm_grad_request() = request; + *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); + + OpResponse response; + + VLOG(2) << "making BatchNormGrad request"; + + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::CrossReplicaSum( const ComputationDataHandle& operand) { if (!first_error_.ok() || !PrepareComputation().ok()) { @@ -1487,6 +1563,28 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::ReducePrecision( + const ComputationDataHandle& operand, const int exponent_bits, + const int mantissa_bits) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReducePrecisionRequest request; + *request.mutable_operand() = operand; + request.set_exponent_bits(exponent_bits); + request.set_mantissa_bits(mantissa_bits); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_reduce_precision_request() 
= request; + AddOpMetadata(&op_request); + OpResponse response; + + VLOG(2) << "making reduce-precision request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + void ComputationBuilder::Send(const ComputationDataHandle& operand, const ChannelHandle& handle) { if (!first_error_.ok() || !PrepareComputation().ok()) { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 5cc73c28d03a097a4fd5b8d3a549ffdc43c6fcd3..94602bd473ffb138d29ca8df86388fe88cf5f312 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -510,6 +510,12 @@ class ComputationBuilder { // Enqueues a sign instruction onto the computation. ComputationDataHandle Sign(const ComputationDataHandle& operand); + // Enqueues a cosine instruction onto the computation. + ComputationDataHandle Cos(const ComputationDataHandle& operand); + + // Enqueues a sine instruction onto the computation. + ComputationDataHandle Sin(const ComputationDataHandle& operand); + // Enqueues a tanh instruction onto the computation. ComputationDataHandle Tanh(const ComputationDataHandle& operand); @@ -597,6 +603,11 @@ class ComputationBuilder { const Computation& body, const ComputationDataHandle& init); + // Enqueues a ReducePrecision node onto the computation. + ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, + const int exponent_bits, + const int mantissa_bits); + // Enqueues a Send node onto the computation, to send the given operand to // a Recv instruction that shares the same channel handle. 
void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); @@ -820,87 +831,80 @@ class ComputationBuilder { template ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantOp( - [value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); }); + return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); }); } template ComputationDataHandle ComputationBuilder::ConstantR1( tensorflow::gtl::ArraySlice values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR1(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR1(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, NativeT value) { return ConstantOp([length, value](Literal* literal) { - LiteralUtil::PopulateWithValue(value, {length}, literal); + literal->PopulateWithValue(value, {length}); }); } inline ComputationDataHandle ComputationBuilder::ConstantR1( const tensorflow::core::Bitmap& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR1(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR1(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR2( std::initializer_list> values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR2(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR2(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR2FromArray2DWithLayout(values, layout, literal); + literal->PopulateR2FromArray2DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp([&values](Literal* literal) { - 
LiteralUtil::PopulateR2FromArray2D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR3FromArray3DWithLayout(values, layout, literal); + literal->PopulateR3FromArray3DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR3FromArray3D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal); + literal->PopulateR4FromArray4DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR4FromArray4D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 86b16be62f041ae3e96591627501592b34203e16..ee3468208792879c3fe4ff5860e434ef5a0c0155 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", ], ) @@ -32,6 +33,7 @@ cc_library( srcs = ["testing.cc"], 
hdrs = ["testing.h"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index a45974b86b67c14868fcfe9c5f8a43445a35807e..969b0eee1d195a36728f16a598add4b3b850ed60 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -22,65 +22,85 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { +namespace { +using InstructionGenerator = + ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, + const ComputationDataHandle&); + +Computation CreateScalarComputation(const string& name, PrimitiveType type, + ComputationBuilder* builder, + InstructionGenerator generator) { + std::unique_ptr b; + if (type == PRED) { + b = builder->CreateSubBuilder(name); + } else { + b = builder->CreateSubBuilder( + tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + } -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder) { const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto b = builder->CreateSubBuilder("add_" + PrimitiveType_Name(type)); auto lhs = b->Parameter(0, scalar, "lhs"); auto rhs = b->Parameter(1, scalar, "rhs"); - b->Add(lhs, rhs); + generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } +} // namespace + +Computation CreateScalarAddComputation(PrimitiveType type, + ComputationBuilder* builder) { + return CreateScalarComputation( + "add", type, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); }); +} + +Computation 
CreateScalarMultiplyComputation(PrimitiveType type, + ComputationBuilder* builder) { + return CreateScalarComputation( + "mul", type, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); +} Computation CreateScalarGeComputation(PrimitiveType type, ComputationBuilder* builder) { - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto b = builder->CreateSubBuilder("ge_" + PrimitiveType_Name(type)); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->Ge(lhs, rhs); - return b->BuildAndNoteError(); + return CreateScalarComputation( + "ge", type, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); }); } Computation CreateScalarMaxComputation(PrimitiveType type, ComputationBuilder* builder) { - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto b = builder->CreateSubBuilder("max_" + PrimitiveType_Name(type)); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->Max(lhs, rhs); - return b->BuildAndNoteError(); + return CreateScalarComputation( + "max", type, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); }); } Computation CreateScalarMinComputation(PrimitiveType type, ComputationBuilder* builder) { - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto b = builder->CreateSubBuilder("min_" + PrimitiveType_Name(type)); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(lhs, rhs); - return b->BuildAndNoteError(); + return CreateScalarComputation( + "min", type, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); } Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder) { -
const Shape scalar = ShapeUtil::MakeShape(PRED, {}); - auto b = builder->CreateSubBuilder("logical_and"); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->LogicalAnd(lhs, rhs); - return b->BuildAndNoteError(); + return CreateScalarComputation( + "logical_and", PRED, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->LogicalAnd(lhs, rhs); }); } Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder) { - const Shape scalar = ShapeUtil::MakeShape(PRED, {}); - auto b = builder->CreateSubBuilder("logical_or"); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->LogicalOr(lhs, rhs); - return b->BuildAndNoteError(); + return CreateScalarComputation( + "logical_or", PRED, builder, + [](ComputationBuilder* b, const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs) { return b->LogicalOr(lhs, rhs); }); } StatusOr Any(const ComputationDataHandle& predicates, diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 633086a2e7e4609543c465c9f52dc452ce3fabb3..f43d35fe4a52016d4054af28835d6b66a35217d4 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -28,6 +28,10 @@ namespace xla { Computation CreateScalarAddComputation(PrimitiveType type, ComputationBuilder* builder); +// Creates a scalar multiply computation and returns it. +Computation CreateScalarMultiplyComputation(PrimitiveType type, + ComputationBuilder* builder); + // Creates a scalar ge computation and returns it. 
Computation CreateScalarGeComputation(PrimitiveType type, ComputationBuilder* builder); diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index daa1557df0b97ee20679f45b8d54164ca93555fa..d8bfc945807d061234c1bc5999ea377a72e85a62 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,11 +35,11 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, client, tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); // TODO(b/26811613): Replace this when RNG is supported on all backends. - b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())), + b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); Computation computation = b.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 67f3a6c1df4d74e5ef714dcaa56bae1e81f8276a..33d5b6f1d4d15d5143a3421c87eab9b7a7d11345 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { return execution_profile_; } +ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( + 
DeviceAssignment* device_assignment) { + device_assignment_ = device_assignment; + return *this; +} + +DeviceAssignment* ExecutableRunOptions::device_assignment() const { + return device_assignment_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 03f2d016ad07b63e6b7d9681c86885ce947f5319..deb3ddb203d263d25bef0499a8a53a6098d0de0c 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -40,6 +40,7 @@ struct ThreadPoolDevice; namespace xla { class DeviceMemoryAllocator; +class DeviceAssignment; class ExecutionProfile; // Class containing options for running a LocalExecutable. @@ -79,9 +80,14 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile() const; ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); + ExecutableRunOptions& set_device_assignment( + DeviceAssignment* device_assignment); + DeviceAssignment* device_assignment() const; + private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; + DeviceAssignment* device_assignment_ = nullptr; perftools::gputools::Stream* stream_ = nullptr; tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/compiler/xla/execution_options_util.cc similarity index 50% rename from tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html rename to tensorflow/compiler/xla/execution_options_util.cc index a325f0a04cd033dd89b870a2fc6eca9a7a6f0020..e83ff7cddd675197c7f6d7018257edb4c25b6228 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -1,6 +1,4 @@ - +==============================================================================*/ +#include
"tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" - - +namespace xla { - - - - - +} // namespace xla diff --git a/tensorflow/compiler/xla/execution_options_util.h b/tensorflow/compiler/xla/execution_options_util.h new file mode 100644 index 0000000000000000000000000000000000000000..562da78e837ea6c4a01f0d1170797340fd421ad8 --- /dev/null +++ b/tensorflow/compiler/xla/execution_options_util.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { + +// Create a default ExecutionOptions proto; this proto has its debug options +// populated to the default values taken from flags. +ExecutionOptions CreateDefaultExecutionOptions(); + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 119c4e373f7c52993f6dbbdfe1554d818746ed1d..35a563bf22701b50c6bfed9193f8b17ffcb1ca90 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -23,7 +23,6 @@ limitations under the License.
#include #include -#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,35 +38,17 @@ limitations under the License. namespace xla { namespace { -using DimensionOrder = legacy_flags::DefaultLayout::DimensionOrder; - // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. void SetDefaultLayoutToContainer( tensorflow::protobuf::RepeatedField* minor_to_major) { + // The default XLA layout is major-to-minor (dim 0 is major). + // For more information on XLA layouts, see: + // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); - legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags(); - auto default_layout = flags->xla_default_layout; - switch (default_layout.dimension_order) { - case DimensionOrder::kMajorToMinor: - for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); - } - break; - case DimensionOrder::kMinorToMajor: - for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, i); - } - break; - case DimensionOrder::kRandom: - for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, i); - } - std::shuffle( - minor_to_major->begin(), minor_to_major->end(), - std::mt19937(default_layout.seed != 0 ? default_layout.seed - : std::random_device()())); + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, size - 1 - i); } } diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index d3fcccff654fbbafa0b3c6a3d900123691f059fb..331bb9afa94e9e7c97d9c880dbac31c60ac0da18 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -210,13 +209,6 @@ TEST_F(LayoutUtilTest, IsPadded) { } TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { - // Test that LayoutUtil returns expected layouts when the xla_default_layout - // flag is set to kMajorToMinor. - legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags(); - flags->xla_default_layout = xla::legacy_flags::DefaultLayout{ - .dimension_order = - legacy_flags::DefaultLayout::DimensionOrder::kMajorToMinor}; - EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), LayoutUtil::GetDefaultLayoutForR2())); EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({2, 1, 0}), @@ -229,25 +221,5 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, DefaultLayoutGettersMinorToMajor) { - // Test that LayoutUtil returns expected layouts when the xla_default_layout - // flag is set to kMinorToMajor. 
- legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags(); - flags->xla_default_layout = xla::legacy_flags::DefaultLayout{ - .dimension_order = - legacy_flags::DefaultLayout::DimensionOrder::kMinorToMajor}; - - EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), - LayoutUtil::GetDefaultLayoutForR2())); - EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), - LayoutUtil::GetDefaultLayoutForR3())); - EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2, 3}), - LayoutUtil::GetDefaultLayoutForR4())); - EXPECT_TRUE( - LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2, 3, 4}), - LayoutUtil::GetDefaultLayoutForShape( - ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index a147ce67a28884d485280b4d811875d569fad879..b47c82f075a1b71dd355bd86ae7200360ab0f388 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -41,31 +41,6 @@ cc_test( ], ) -cc_library( - name = "layout_util_flags", - srcs = ["layout_util_flags.cc"], - hdrs = ["layout_util_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "util_flags", - srcs = ["util_flags.cc"], - hdrs = ["util_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "debug_options_flags", srcs = ["debug_options_flags.cc"], @@ -73,188 +48,12 @@ cc_library( deps = [ ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], ) -cc_library( - name = "cpu_compiler_flags", - srcs = ["cpu_compiler_flags.cc"], - hdrs = ["cpu_compiler_flags.h"], - deps = - [ - 
":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "cpu_runtime_flags", - srcs = ["cpu_runtime_flags.cc"], - hdrs = ["cpu_runtime_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "compiler_functor_flags", - srcs = ["compiler_functor_flags.cc"], - hdrs = ["compiler_functor_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "convolution_thunk_flags", - srcs = ["convolution_thunk_flags.cc"], - hdrs = ["convolution_thunk_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "gpu_compiler_flags", - srcs = ["gpu_compiler_flags.cc"], - hdrs = ["gpu_compiler_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "gpu_backend_lib_flags", - srcs = ["gpu_backend_lib_flags.cc"], - hdrs = ["gpu_backend_lib_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "stream_assignment_flags", - srcs = ["stream_assignment_flags.cc"], - hdrs = ["stream_assignment_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "hlo_graph_dumper_flags", - srcs = ["hlo_graph_dumper_flags.cc"], - hdrs = ["hlo_graph_dumper_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = 
"alias_analysis_flags", - srcs = ["alias_analysis_flags.cc"], - hdrs = ["alias_analysis_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "llvm_util_flags", - srcs = ["llvm_util_flags.cc"], - hdrs = ["llvm_util_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "service_flags", - srcs = ["service_flags.cc"], - hdrs = ["service_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "buffer_assignment_flags", - srcs = ["buffer_assignment_flags.cc"], - hdrs = ["buffer_assignment_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "hlo_test_base_flags", - srcs = ["hlo_test_base_flags.cc"], - hdrs = ["hlo_test_base_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "backend_flags", - srcs = ["backend_flags.cc"], - hdrs = ["backend_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "user_computation_flags", - srcs = ["user_computation_flags.cc"], - hdrs = ["user_computation_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc deleted file mode 100644 index 
474753c10ad7ed5eb4a9a446c3f877280c5ad302..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's alias_analysis module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static AliasAnalysisFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new AliasAnalysisFlags; - flags->xla_emit_alias_scope = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope, - "Use buffer analysis to refine alias-analysis."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's alias_analysis -// module. 
-void AppendAliasAnalysisFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the AliasAnalysisFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -AliasAnalysisFlags* GetAliasAnalysisFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h deleted file mode 100644 index 369f8cd7caa6f42273cd405ca5f43d325e457128..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ - -// Legacy flags for XLA's alias_analysis module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's alias_analysis -// module. 
-void AppendAliasAnalysisFlags(std::vector* flag_list); - -// The values of flags associated with XLA's alias_analysis module. -typedef struct { - bool xla_emit_alias_scope; // Use buffer analysis to refine alias-analysis. -} AliasAnalysisFlags; - -// Return a pointer to the AliasAnalysisFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -AliasAnalysisFlags* GetAliasAnalysisFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/backend_flags.cc b/tensorflow/compiler/xla/legacy_flags/backend_flags.cc deleted file mode 100644 index 7c007f4435c088b35bffce40372f88f37af6ed5b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/backend_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's backend module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static BackendFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new BackendFlags; - // TODO(b/32648682): Decide if this should continue to be a flag longer term. - flags->xla_replicas = 1; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_replicas", &flags->xla_replicas, - "The number of replicas to use. 1 means no replication."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's backend module. -void AppendBackendFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the BackendFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BackendFlags* GetBackendFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/backend_flags.h b/tensorflow/compiler/xla/legacy_flags/backend_flags.h deleted file mode 100644 index 061238b7e690257f4eb681558dcd59b1f8ba2653..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/backend_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_ - -// Legacy flags for XLA's backend module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's backend module. -void AppendBackendFlags(std::vector* flag_list); - -// The values of flags associated with XLA's backend module. -typedef struct { - int64 xla_replicas; // The number of replicas to use. 1 means no - // replication. -} BackendFlags; - -// Return a pointer to the BackendFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BackendFlags* GetBackendFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc deleted file mode 100644 index 71873f73afd5bb8c59832a4c82f87f4e51c31180..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's buffer_assignment module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static BufferAssignmentFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new BufferAssignmentFlags; - flags->xla_enable_buffer_reuse = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_enable_buffer_reuse", - &flags->xla_enable_buffer_reuse, - "Enable reuse of buffers."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's buffer_assignment -// module. -void AppendBufferAssignmentFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the BufferAssignmentFlags struct; -// repeated calls return the same pointer. 
-// This should be called only after Flags::Parse() has returned. -BufferAssignmentFlags* GetBufferAssignmentFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h deleted file mode 100644 index 5f098c2663f638940aead45b74332edcf3fcc37f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ - -// Legacy flags for XLA's buffer_assignment module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's buffer_assignment -// module. -void AppendBufferAssignmentFlags(std::vector* flag_list); - -// The values of flags associated with XLA's buffer_assignment module. -typedef struct { - bool xla_enable_buffer_reuse; // Enable reuse of buffers. 
-} BufferAssignmentFlags; - -// Return a pointer to the BufferAssignmentFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BufferAssignmentFlags* GetBufferAssignmentFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc deleted file mode 100644 index 617a9b712ed99d343dc28b6e6c0de4b54e271096..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's compiler_functor module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. 
-static CompilerFunctorFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new CompilerFunctorFlags; - flag_list = new std::vector({ - tensorflow::Flag("xla_debug_cpu_dump_ir", &flags->xla_debug_cpu_dump_ir, - "Dump IR, before optimizations to a path"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's compiler_functor -// module. -void AppendCompilerFunctorFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CompilerFunctorFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CompilerFunctorFlags* GetCompilerFunctorFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h deleted file mode 100644 index 28b505ec5eac2d74879a22779137c6982a7c9ce8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ - -// Legacy flags for the XLA's compiler_functor module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's compiler_functor -// module. -void AppendCompilerFunctorFlags(std::vector* flag_list); - -// The values of flags associated with XLA's compiler_functor module. -typedef struct { - string xla_debug_cpu_dump_ir; // Dump IR, before optimizations to a path -} CompilerFunctorFlags; - -// Return a pointer to the CompilerFunctorFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CompilerFunctorFlags* GetCompilerFunctorFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc deleted file mode 100644 index fe5d19147f09557817fee5c670f52058f21f5cdc..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's convolution_thunk module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ConvolutionThunkFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new ConvolutionThunkFlags; - flags->xla_gpu_autotune_convolution_algorithm = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_gpu_autotune_convolution_algorithm", - &flags->xla_gpu_autotune_convolution_algorithm, - "Auto-tune the algorithm used by convolution"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's convolution_thunk -// module. -void AppendConvolutionThunkFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ConvolutionThunkFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-ConvolutionThunkFlags* GetConvolutionThunkFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h deleted file mode 100644 index 53d6806a71902d1227728f74bd45f12f9d11421d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ - -// Legacy flags for XLA's convolution_thunk module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's convolution_thunk -// module. -void AppendConvolutionThunkFlags(std::vector* flag_list); - -// The values of flags associated with XLA's convolution_thunk module. 
-typedef struct { - // Auto-tune the algorithm used by convolution - bool xla_gpu_autotune_convolution_algorithm; -} ConvolutionThunkFlags; - -// Return a pointer to the ConvolutionThunkFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ConvolutionThunkFlags* GetConvolutionThunkFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc deleted file mode 100644 index 13d41a8636b6ba3aa88545523e93dffe4b0c12f5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's cpu_compiler module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static CpuCompilerFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new CpuCompilerFlags; - flags->xla_cpu_embed_ir = false; - flags->xla_cpu_dump_debug_json_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir, - "Embed the LLVM IR module string in the resultant CpuExecutable."), - tensorflow::Flag("xla_cpu_dump_debug_json_to", - &flags->xla_cpu_dump_debug_json_to, - "Dump debug JSON to this directory."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's cpu_compiler -// module. -void AppendCpuCompilerFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-CpuCompilerFlags* GetCpuCompilerFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h deleted file mode 100644 index bac498e18eb241d3b3044f14c88ac2b3aaaa322f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ - -// Legacy flags for the XLA's cpu_compiler module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's cpu_compiler -// module. -void AppendCpuCompilerFlags(std::vector* flag_list); - -// The values of flags associated with XLA's cpu_compiler module. -typedef struct { - bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant - // CpuExecutable - string xla_cpu_dump_debug_json_to; // Dump debug JSON to this directory. 
-} CpuCompilerFlags; - -// Return a pointer to the CpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuCompilerFlags* GetCpuCompilerFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc deleted file mode 100644 index d7817c5d54a047b1987a19dfbde9f48081ae6413..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's cpu_runtime module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. 
-static CpuRuntimeFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new CpuRuntimeFlags; - flags->xla_cpu_use_eigen = true; - flags->xla_cpu_multi_thread_eigen = true; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_cpu_use_eigen", &flags->xla_cpu_use_eigen, - "Use Eigen for matrix multiply on the CPU platform. This " - "is a useful hack for performance comparisons against " - "XLA's implementation."), - tensorflow::Flag( - "xla_cpu_multi_thread_eigen", &flags->xla_cpu_multi_thread_eigen, - "When generating calls to Eigen for matmul and conv, should " - "single or multi-threaded eigen be used? " - "Only used when --xla_cpu_use_eigen is true."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's cpu_runtime -// module. -void AppendCpuRuntimeFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CpuRuntimeFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuRuntimeFlags* GetCpuRuntimeFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h deleted file mode 100644 index e3ff30da36a5fabd7d7798fd636cb3955a91b09f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ - -// Legacy flags for the XLA's cpu_runtime module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's cpu_runtime -// module. -void AppendCpuRuntimeFlags(std::vector* flag_list); - -// The values of flags associated with XLA's cpu_runtime module. -typedef struct { - // Use Eigen for matrix multiply on the CPU platform. This is a useful hack - // for performance comparisons against XLA's implementation. - bool xla_cpu_use_eigen; - // When generating calls to Eigen for matmul and conv, should single or - // multi-threaded eigen be used? Only used when --xla_cpu_use_eigen is true. - bool xla_cpu_multi_thread_eigen; -} CpuRuntimeFlags; - -// Return a pointer to the CpuRuntimeFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-CpuRuntimeFlags* GetCpuRuntimeFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5e3c4f912bf6073e89a66633c44a7e052ca43ade..87c6215e6badc9f7e4c99f78fb23c8d621b9dbd2 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -23,56 +23,233 @@ limitations under the License. namespace xla { namespace legacy_flags { -struct DebugOptionsFlags { - string xla_generate_hlo_graph; - string xla_disable_hlo_passes; - bool xla_enable_fast_math; - int32 xla_backend_optimization_level; - string xla_backend_extra_options; -}; - namespace { -DebugOptionsFlags* flag_values; +DebugOptions* flag_values; std::vector* flag_objects; std::once_flag flags_init; +namespace { +void SetDebugOptionsDefaults(DebugOptions* flags) { + flags->set_xla_hlo_graph_path("/tmp/"); + flags->set_xla_enable_fast_math(true); + flags->set_xla_llvm_enable_alias_scope_metadata(true); + flags->set_xla_llvm_enable_noalias_metadata(true); + flags->set_xla_llvm_enable_invariant_load_metadata(true); + flags->set_xla_backend_optimization_level(3); + flags->set_xla_cpu_multi_thread_eigen(true); + flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); +} +} // namespace + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. 
void AllocateFlags() { - flag_values = new DebugOptionsFlags; - flag_values->xla_generate_hlo_graph = ""; - flag_values->xla_disable_hlo_passes = ""; - flag_values->xla_enable_fast_math = true; - flag_values->xla_backend_optimization_level = 2; - flag_values->xla_backend_extra_options = ""; + flag_values = new DebugOptions; + + SetDebugOptionsDefaults(flag_values); + + // Returns a lambda that calls "member_setter" on "flag_values" with the + // argument passed in to the lambda. + auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) { + return [member_setter](bool value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + + // Returns a lambda that calls "member_setter" on "flag_values" with the + // argument passed in to the lambda. + auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) { + return [member_setter](int32 value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + + // Returns a lambda that is a custom "sub-parser" for xla_disable_hlo_passes. + auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { + std::vector disabled_passes = + tensorflow::str_util::Split(comma_separated_values, ','); + for (const auto& passname : disabled_passes) { + flag_values->add_xla_disable_hlo_passes(passname); + } + return true; + }; + + // Returns a lambda that is a custom "sub-parser" for + // xla_backend_extra_options. + auto setter_for_xla_backend_extra_options = + [](string comma_separated_values) { + std::vector extra_options_parts = + tensorflow::str_util::Split(comma_separated_values, ','); + auto* extra_options_map = + flag_values->mutable_xla_backend_extra_options(); + + // The flag contains a comma-separated list of options; some options + // have arguments following "=", some don't. 
+ for (const auto& part : extra_options_parts) { + size_t eq_pos = part.find_first_of('='); + if (eq_pos == string::npos) { + (*extra_options_map)[part] = ""; + } else { + string value = ""; + if (eq_pos + 1 < part.size()) { + value = part.substr(eq_pos + 1); + } + (*extra_options_map)[part.substr(0, eq_pos)] = value; + } + } + + return true; + }; flag_objects = new std::vector( {tensorflow::Flag( - "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph, + "xla_generate_hlo_graph", + flag_values->mutable_xla_generate_hlo_graph(), "HLO modules matching this regex will be dumped to a .dot file " "throughout various stages in compilation."), - tensorflow::Flag( - "xla_enable_fast_math", &flag_values->xla_enable_fast_math, + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "With xla_generate_hlo_graph, show addresses of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_layout", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_layout), + flag_values->xla_hlo_graph_layout(), + "With xla_generate_hlo_graph, show layout of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), + "With xla_generate_hlo_graph, dump the graphs into this path."), + tensorflow::Flag( + "xla_hlo_dump_as_graphdef", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), + flag_values->xla_hlo_dump_as_graphdef(), + "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), + "HLO modules matching this regex will be dumped to LOG(INFO). 
"), + tensorflow::Flag( + "xla_generate_hlo_text_to", + flag_values->mutable_xla_generate_hlo_text_to(), + "Dump all HLO modules as text into the provided directory path."), + tensorflow::Flag( + "xla_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_enable_fast_math), + flag_values->xla_enable_fast_math(), "Enable unsafe fast-math optimizations in the compiler; " "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of " + "!alias.scope metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of " + "!noalias metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of " + "!invariant.load metadata in " + "the generated IR."), tensorflow::Flag( "xla_backend_optimization_level", - &flag_values->xla_backend_optimization_level, + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), "Numerical optimization level for the XLA compiler backend."), - + tensorflow::Flag( + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", + "Comma-separated list of hlo passes to be disabled. 
These names " + "must exactly match the passes' names; no whitespace around " + "commas."), + tensorflow::Flag( + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag( + "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), + "Dump the compiler IR into this directory as individual files."), + tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for( + &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), + tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, " + "use multi-threaded Eigen mode."), + tensorflow::Flag("xla_gpu_cuda_data_dir", + flag_values->mutable_xla_gpu_cuda_data_dir(), + "If non-empty, speficies a local directory containing " + "ptxas and nvvm libdevice files; otherwise we use " + "those from runfile directories."), + tensorflow::Flag("xla_gpu_ftz", + bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), + "If true, flush-to-zero semantics are enabled in the " + "code generated for GPUs."), + tensorflow::Flag( + "xla_gpu_disable_multi_streaming", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), + flag_values->xla_gpu_disable_multi_streaming(), + "If true, multi-streaming in the GPU backend is disabled."), + tensorflow::Flag( + "xla_dump_debug_json_to", + flag_values->mutable_xla_dump_debug_json_to(), + "Dump compilation artifacts as JSON into this directory."), + tensorflow::Flag( + "xla_test_all_output_layouts", + 
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of output layouts. For example, with " + "a 3D shape, all permutations of the set {0, 1, 2} are " + "tried."), + tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of *input* layouts. For example, for " + "2 input arguments with 2D shape and 4D shape, the " + "computation will run 2! * 4! times for every possible " + "layouts"), + tensorflow::Flag( + "xla_hlo_profile", + bool_setter_for(&DebugOptions::set_xla_hlo_profile), + flag_values->xla_hlo_profile(), + "Instrument the computation to collect per-HLO cycle counts"), + tensorflow::Flag("xla_dump_computations_to", + flag_values->mutable_xla_dump_computations_to(), + "Dump computations that XLA executes into the provided " + "directory path"), + tensorflow::Flag("xla_dump_executions_to", + flag_values->mutable_xla_dump_executions_to(), + "Dump parameters and results of computations that XLA " + "executes into the provided directory path"), tensorflow::Flag("xla_backend_extra_options", - &flag_values->xla_backend_extra_options, + setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; " "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), - - tensorflow::Flag( - "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, - "Comma-separated list of HLO passes to be disabled. 
These names " - "must exactly match the passes' names; " - "no whitespace around commas.")}); + "may be omitted); no whitespace around commas.")}); ParseFlagsFromEnv(*flag_objects); } @@ -86,40 +263,7 @@ void AppendDebugOptionsFlags(std::vector* flag_list) { xla::DebugOptions GetDebugOptionsFromFlags() { std::call_once(flags_init, &AllocateFlags); - - DebugOptions options; - options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph); - - std::vector disabled_passes = - tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ','); - for (const auto& passname : disabled_passes) { - options.add_xla_disable_hlo_passes(passname); - } - - options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math); - options.set_xla_backend_optimization_level( - flag_values->xla_backend_optimization_level); - - std::vector extra_options_parts = - tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ','); - auto* extra_options_map = options.mutable_xla_backend_extra_options(); - - // The flag contains a comma-separated list of options; some options have - // arguments following "=", some don't. - for (const auto& part : extra_options_parts) { - size_t eq_pos = part.find_first_of('='); - if (eq_pos == string::npos) { - (*extra_options_map)[part] = ""; - } else { - string value = ""; - if (eq_pos + 1 < part.size()) { - value = part.substr(eq_pos + 1); - } - (*extra_options_map)[part.substr(0, eq_pos)] = value; - } - } - - return options; + return *flag_values; } } // namespace legacy_flags diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc deleted file mode 100644 index f8f6ea26b1d0df67b934616fe60aa29199fc2eb9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's gpu_backend_lib module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static GpuBackendLibFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new GpuBackendLibFlags; - flags->dump_temp_products_to = ""; - flags->ftz = false; - flags->fma = true; - flags->verbose_ptx_asm = false; - flags->kernel = ""; - flags->llvm_dump_passes = false; - flags->llvm_cl_opts = ""; - flags->dump_ir_before_passes = false; - flags->opt_level = 3; - flag_list = new std::vector({ - tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to, - "dump temporary compilation products to this directory. 
" - "If empty, no dump is produced"), - tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"), - tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"), - tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm, - "emit PTX assembly with extra comments"), - tensorflow::Flag("kernel", &flags->kernel, - "only emit the IR and PTX for this kernel"), - tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes, - "dump the passes LLVM runs to stderr"), - tensorflow::Flag( - "llvm_cl_opts", &flags->llvm_cl_opts, - "comma-separated list of command line options to pass to " - "LLVM. For example, --llvm_cl_opts=--print-before=loop-unroll"), - tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes, - "dump the IR before each optimization pass in " - "sequentially-named files."), - tensorflow::Flag("opt_level", &flags->opt_level, - "optimization level (default to 3)"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's gpu_backend_lib -// module. -void AppendGpuBackendLibFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the GpuBackendLibFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuBackendLibFlags* GetGpuBackendLibFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h deleted file mode 100644 index 31cb50e9da986b5bad3e71439a4976ec84e17be7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ - -// Legacy flags for XLA's gpu_backend_lib module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib -// module. -void AppendGpuBackendLibFlags(std::vector* flag_list); - -// The values of flags associated with XLA's gpu_backend_lib module. -typedef struct { - string dump_temp_products_to; // temporary compilation products dir - bool ftz; // flush to zero semantics - bool fma; // use FMA synthesis - bool verbose_ptx_asm; // emit PTX assembly with extra comments - string kernel; // only emit the IR and PTX for this kernel - bool llvm_dump_passes; // dump the passes LLVM runs to stderr - string llvm_cl_opts; // comma-separated list of LLVM options - bool dump_ir_before_passes; // dump IR before each pass - int32 opt_level; // optimization level -} GpuBackendLibFlags; - -// Return a pointer to the GpuBackendLibFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-GpuBackendLibFlags* GetGpuBackendLibFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc deleted file mode 100644 index 131e3ce70ac9e7fc2f6f233ffd93e8757d0bc725..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's gpu_compiler module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static GpuCompilerFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new GpuCompilerFlags; - flags->xla_gpu_embed_ir = false; - flags->xla_cuda_data_dir = "./cuda_sdk_lib"; - flags->xla_gpu_dump_debug_json_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir, - "Embed the LLVM IR module string in the resultant GpuExecutable."), - tensorflow::Flag( - "xla_cuda_data_dir", &flags->xla_cuda_data_dir, - "If non-empty, specifies a local directory containing ptxas and " - "nvvm libdevice files. Otherwise, by default, we use those from " - "runfile directories."), - tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path, - "The path to ptxas. Required to log stats of the ptx."), - tensorflow::Flag("xla_gpu_dump_debug_json_to", - &flags->xla_gpu_dump_debug_json_to, - "Dump debug JSON to this directory."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's gpu_compiler -// module. -void AppendGpuCompilerFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the GpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuCompilerFlags* GetGpuCompilerFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h deleted file mode 100644 index 0cf39e0ab35e663c7abc14980daa8b92d15489d6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ - -// Legacy flags for XLA's gpu_compiler module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's gpu_compiler -// module. -void AppendGpuCompilerFlags(std::vector* flag_list); - -// The values of flags associated with XLA's gpu_compiler module. -typedef struct { - bool xla_gpu_embed_ir; // Embed the LLVM IR module string in the resultant - // GpuExecutable. - string xla_cuda_data_dir; // If non-empty, specifies a local directory - // containing ptxas and nvvm libdevice files. - // Otherwise, by default, we use those from runfile - // directories. - string xla_ptxas_path; // The path to ptxas. Required to log stats of - // the ptx. - string xla_gpu_dump_debug_json_to; // Dump debug JSON to this directory. -} GpuCompilerFlags; - -// Return a pointer to the GpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-GpuCompilerFlags* GetGpuCompilerFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc deleted file mode 100644 index ba43a5919522ff783f450481c629d64613e1f8ab..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's hlo_graph_dumper module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static HloGraphDumperFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new HloGraphDumperFlags; - flags->xla_hlo_dump_graph_path = "/tmp/"; - flags->xla_hlo_dump_as_graphdef = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_hlo_dump_graph_path", - &flags->xla_hlo_dump_graph_path, - "Path to write dumped HLO graphs to"), - tensorflow::Flag("xla_hlo_dump_as_graphdef", - &flags->xla_hlo_dump_as_graphdef, - "Dumps HLO graphs as tensorflow GraphDefs"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_graph_dumper -// module. -void AppendHloGraphDumperFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the HloGraphDumperFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloGraphDumperFlags* GetHloGraphDumperFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h deleted file mode 100644 index d0b4d092ff1003bc1df90c3d878feacf71a5aa21..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ - -// Legacy flags for XLA's hlo_graph_dumper module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's hlo_graph_dumper -// module. -void AppendHloGraphDumperFlags(std::vector* flag_list); - -// The values of flags associated with XLA's hlo_graph_dumper module. -typedef struct { - string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to - // If set, dumps HLO graphs as tensorflow GraphDef; otherwise, dumps HLO - // graphs as DOT graph. - bool xla_hlo_dump_as_graphdef; -} HloGraphDumperFlags; - -// Return a pointer to the HloGraphDumperFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloGraphDumperFlags* GetHloGraphDumperFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc deleted file mode 100644 index c7893c138596b034dbb83df9fda2d4c5edd8e32b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's hlo_test_base module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static HloTestBaseFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new HloTestBaseFlags; - flags->xla_hlo_test_generate_hlo_graph = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_hlo_test_generate_hlo_graph", - &flags->xla_hlo_test_generate_hlo_graph, - "Generate graph output of HLO instructions"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_test_base -// module. -void AppendHloTestBaseFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the HloTestBaseFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-HloTestBaseFlags* GetHloTestBaseFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h deleted file mode 100644 index 23b808cecb7e5eaf480292f5207a4b87ebd4a2d5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ - -// Legacy flags for XLA's hlo_test_base module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's hlo_test_base -// module. -void AppendHloTestBaseFlags(std::vector* flag_list); - -// The values of flags associated with XLA's hlo_test_base module. -typedef struct { - bool xla_hlo_test_generate_hlo_graph; // Generate graph output of HLO - // instructions -} HloTestBaseFlags; - -// Return a pointer to the HloTestBaseFlags struct; -// repeated calls return the same pointer. 
-// This should be called only after Flags::Parse() has returned. -HloTestBaseFlags* GetHloTestBaseFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc deleted file mode 100644 index f838861898ddd08b56a13f9b8f722f3c1e4da5eb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's layout_util module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the string value of the xla_default_layout flag and the flag -// descriptor, initialized via raw_flags_init. 
-static string* raw_flag; -static std::vector* flag_list; -static std::once_flag raw_flags_init; - -// Allocate *raw_flag. Called via call_once(&raw_flags_init,...). -static void AllocateRawFlag() { - raw_flag = new string; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_default_layout", raw_flag, - "Default layout for Shapes in XLA. Valid values are: " - "'minor2major', 'major2minor', 'random', 'random:'. " - "For debugging purposes. If no seed (or 0) is given, a seed from " - "random_device is used."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Parse text into *layout. -static bool ParseDefaultLayout(const string& text, DefaultLayout* layout) { - bool result = true; - std::vector field = tensorflow::str_util::Split(text, ':'); - if (!field.empty()) { - if (field[0] == "random") { - layout->dimension_order = DefaultLayout::DimensionOrder::kRandom; - if (field.size() > 1) { - uint64 seed = 0; - result = tensorflow::strings::safe_strtou64(field[1], &seed); - layout->seed = seed; - } - } else if (field[0] == "minor2major") { - layout->dimension_order = DefaultLayout::DimensionOrder::kMinorToMajor; - } else if (field[0] == "major2minor") { - layout->dimension_order = DefaultLayout::DimensionOrder::kMajorToMinor; - } else { - result = false; - } - } - return result; -} - -// Pointer to the parsed value of the flags, initialized via flags_init. -static LayoutUtilFlags* flags; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - std::call_once(raw_flags_init, &AllocateRawFlag); - flags = new LayoutUtilFlags; - flags->xla_default_layout.dimension_order = - DefaultLayout::DimensionOrder::kMajorToMinor; - flags->xla_default_layout.seed = 0; - if (!ParseDefaultLayout(*raw_flag, &flags->xla_default_layout)) { - flags = nullptr; - } -} - -// Append to *append_to the flag definitions associated with XLA's layout_util -// module. 
-void AppendLayoutUtilFlags(std::vector* append_to) { - std::call_once(raw_flags_init, &AllocateRawFlag); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the LayoutUtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -LayoutUtilFlags* GetLayoutUtilFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h deleted file mode 100644 index 177f428b734dcdf703472f3e240aef9792f988d7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_ - -// Legacy flags for the XLA's layout_util module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// The default layout for all newly created shapes. Specified by the flag -// --xla_default_layout. 
-struct DefaultLayout { - enum class DimensionOrder { - kRandom, - kMinorToMajor, - kMajorToMinor, - }; - - DimensionOrder dimension_order; - size_t seed; -}; - -// Append to *flag_list the flag definitions associated with XLA's layout_util -// module. -void AppendLayoutUtilFlags(std::vector* flag_list); - -// The values of flags associated with XLA's layout_util module. -typedef struct { - // Default layout for Shapes in XLA. Valid values are: 'minor2major', - // 'major2minor', 'random', 'random:'. For debugging purposes. If no - // seed (or 0) is given, a seed from random_device is used. - DefaultLayout xla_default_layout; -} LayoutUtilFlags; - -// Return a pointer to the LayoutFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -LayoutUtilFlags* GetLayoutUtilFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc deleted file mode 100644 index 3c53729a67049fdac6b358149e06f39858ebd98f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// Legacy flags for XLA's llvm_util module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static LlvmUtilFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new LlvmUtilFlags; - flags->xla_emit_tbaa = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa, - "Perform type-based alias analysis optimizations for " - "LLVM-based backends."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's llvm_util -// module. -void AppendLlvmUtilFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the LlvmUtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-LlvmUtilFlags* GetLlvmUtilFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h deleted file mode 100644 index 98da26b4b806dd83c7baf6bdcf60cbf5297457a6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ - -// Legacy flags for XLA's llvm_util module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's llvm_util module. -void AppendLlvmUtilFlags(std::vector* flag_list); - -// The values of flags associated with XLA's llvm_util module. -typedef struct { - bool xla_emit_tbaa; // Perform type-based alias analysis optimizations for - // LLVM-based backends. -} LlvmUtilFlags; - -// Return a pointer to the LlvmUtilFlags struct; -// repeated calls return the same pointer. 
-// This should be called only after Flags::Parse() has returned. -LlvmUtilFlags* GetLlvmUtilFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.cc b/tensorflow/compiler/xla/legacy_flags/service_flags.cc deleted file mode 100644 index 41cb8d8bdfc51de1d8fe77906317b4b4a0804802..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/service_flags.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's service module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ServiceFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new ServiceFlags; - flags->xla_hlo_profile = false; - flags->xla_log_hlo_text = ""; - flags->xla_generate_hlo_graph = ""; - flags->xla_hlo_graph_addresses = false; - flags->xla_hlo_graph_layout = false; - flags->xla_hlo_graph_for_compute_constant = false; - flags->xla_dump_computations_to = ""; - flags->xla_dump_hlo_text_to = ""; - flags->xla_dump_executions_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_hlo_profile", &flags->xla_hlo_profile, - "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag( - "xla_log_hlo_text", &flags->xla_log_hlo_text, - "If non-empty, print the text format of " - "HLO modules whose name partially matches this regex. E.g. " - "xla_log_hlo_text=.* will dump the text for every module."), - tensorflow::Flag( - "xla_generate_hlo_graph", &flags->xla_generate_hlo_graph, - "If non-empty, dump graph of HLO modules whose name partially " - "matches this regex. E.g. --xla_generate_hlo_graph=.* will dump " - "the graph of every module."), - tensorflow::Flag("xla_hlo_graph_addresses", - &flags->xla_hlo_graph_addresses, - "Show addresses of HLO ops in graph"), - tensorflow::Flag("xla_hlo_graph_layout", &flags->xla_hlo_graph_layout, - "Show layout of HLO ops in graph"), - tensorflow::Flag( - "xla_hlo_graph_for_compute_constant", - &flags->xla_hlo_graph_for_compute_constant, - "If true, include hlo dumps of graphs from ComputeConstant." 
- "Such graphs still need to be matched via xla_generate_hlo_graph."), - tensorflow::Flag("xla_dump_computations_to", - &flags->xla_dump_computations_to, - "Dumps computations that XLA executes into the provided " - "directory path"), - tensorflow::Flag("xla_dump_hlo_text_to", &flags->xla_dump_hlo_text_to, - "Dumps HLO modules that XLA executes into the provided " - "directory path"), - tensorflow::Flag("xla_dump_executions_to", &flags->xla_dump_executions_to, - "Dumps parameters and results of computations that XLA " - "executes into the provided directory path"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's service module. -void AppendServiceFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ServiceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ServiceFlags* GetServiceFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.h b/tensorflow/compiler/xla/legacy_flags/service_flags.h deleted file mode 100644 index d982506944daed41eb6e7c4a238d540b38cf8be3..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/service_flags.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_ - -// Legacy flags for XLA's service module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's service module. -void AppendServiceFlags(std::vector* flag_list); - -// The values of flags associated with XLA's service module. -typedef struct { - bool xla_hlo_profile; // Instrument the computation to collect per-HLO cycle - // counts - string xla_log_hlo_text; // If non-empty, print the text format of the HLO - // modules whose name partially - // matches this regex. E.g. xla_log_hlo_text=.* - // will dump the text for every module. - string xla_generate_hlo_graph; // If non-empty, dump graph of HLO modules - // whose name partially matches this regex. - // E.g. --xla_generate_hlo_graph=.* will dump - // the graph of every module. - bool xla_hlo_graph_addresses; // Show addresses of HLO ops in graph - bool xla_hlo_graph_layout; // Show layout of HLO ops in graph - bool xla_hlo_graph_for_compute_constant; // If true, include hlo dumps of - // graphs from ComputeConstant. - // Such graphs still need to be - // matched via - // xla_generate_hlo_graph. 
- string xla_dump_hlo_text_to; // Dumps HLO text for each HLO module that is - // executed into the provided directory path - string xla_dump_computations_to; // Dumps computations that XLA executes - // into the provided directory path - // Dumps parameters and results of computations that XLA executes into - // the provided directory path - string xla_dump_executions_to; -} ServiceFlags; - -// Return a pointer to the ServiceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ServiceFlags* GetServiceFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc deleted file mode 100644 index 6506175777ccd262b6467f8fbe6de8bb24eff945..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's stream_assignment module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static StreamAssignmentFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new StreamAssignmentFlags; - flags->xla_gpu_disable_multi_streaming = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_gpu_disable_multi_streaming", - &flags->xla_gpu_disable_multi_streaming, - "Disable multi-streaming in XLA's GPU backend"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's stream_assignment -// module. -void AppendStreamAssignmentFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the StreamAssignmentFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -StreamAssignmentFlags* GetStreamAssignmentFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h deleted file mode 100644 index a98f9b34584b43161aa8e3248c28d520403f3f3a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_ - -// Legacy flags for XLA's stream_assignment module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's stream_assignment -// module. -void AppendStreamAssignmentFlags(std::vector* flag_list); - -// The values of flags associated with XLA's stream_assignment module. -typedef struct { - bool xla_gpu_disable_multi_streaming; // Disable multi-streaming in XLA's GPU - // backend -} StreamAssignmentFlags; - -// Return a pointer to the StreamAssignmentFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-StreamAssignmentFlags* GetStreamAssignmentFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc deleted file mode 100644 index a9597d0cd8f89d7d664c38b79d225b0aa6b6b13b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static UserComputationFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new UserComputationFlags; - flags->xla_eliminate_hlo_implicit_broadcast = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", - &flags->xla_eliminate_hlo_implicit_broadcast, - "Eliminate implicit broadcast on when lowering user " - "computation to HLO instructions, use explicit " - "broadcast instead."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendUserComputationFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the UserComputationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -UserComputationFlags* GetUserComputationFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h deleted file mode 100644 index f5222c927cb203b901fb3bc6ea3d2e7d30cb658a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ - -// Legacy flags for XLA's user_computation module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flags definitions associated with XLA's user_computation -// module. -void AppendUserComputationFlags(std::vector* flag_list); - -typedef struct { - // Eliminate implicit broadcast on when lowering user computation to HLO - // instructions, use explicit broadcast instead. - bool xla_eliminate_hlo_implicit_broadcast; -} UserComputationFlags; - -// Return a pointer to the UserComputationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -UserComputationFlags* GetUserComputationFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/util_flags.cc b/tensorflow/compiler/xla/legacy_flags/util_flags.cc deleted file mode 100644 index e6df19ddd2afbbf14149d77a1e0652df209f58fe..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/util_flags.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's util module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/legacy_flags/util_flags.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static UtilFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new UtilFlags; - flags->xla_status_add_backtrace = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_status_add_backtrace", - &flags->xla_status_add_backtrace, - "add backtraces to XLA-produced status values"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's util module. -void AppendUtilFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the UtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-UtilFlags* GetUtilFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/util_flags.h b/tensorflow/compiler/xla/legacy_flags/util_flags.h deleted file mode 100644 index 03bffcd726f0544a185f5e8403ad2c45318bd0ad..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/util_flags.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_ - -// Legacy flags for the XLA's util module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's util module. -void AppendUtilFlags(std::vector* flag_list); - -// The values of flags associated with XLA's util module. -typedef struct { - bool xla_status_add_backtrace; // add backtraces to XLA-produced statuses -} UtilFlags; - -// Return a pointer to the UtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-UtilFlags* GetUtilFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index caef3a3869f4bcde7a6982ce3dfc0db9d36cbc5e..0db9bd757d420d8ecf281b6ec936c3f34ee23617 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,7 +62,17 @@ Literal::StrideConfig::StrideConfig( std::unique_ptr Literal::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(); *literal->mutable_shape() = shape; - literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); + if (ShapeUtil::IsTuple(shape)) { + int64 num_elements = ShapeUtil::TupleElementCount(shape); + literal->tuple_literals_.resize(num_elements); + for (int i = 0; i < num_elements; ++i) { + std::unique_ptr elem = + CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i)); + literal->tuple_literals_[i] = std::move(*elem); + } + } else { + literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); + } return literal; } @@ -321,6 +331,7 @@ Status Literal::Copy(const Literal& src_literal, } std::unique_ptr Literal::Relayout(const Layout& layout) const { + CHECK(ShapeUtil::IsArray(shape())); std::unique_ptr result = CloneToUnique(); *result->mutable_shape()->mutable_layout() = layout; @@ -620,6 +631,18 @@ string Literal::ToString() const { return literal; } +/* static */ std::unique_ptr Literal::MakeTupleOwned( + std::vector> elements) { + auto literal = MakeUnique(); + std::vector shape; + for (auto& tuple_element : elements) { + shape.push_back(tuple_element->shape()); + literal->add_tuple_literals()->Swap(tuple_element.get()); + } + *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); + return literal; +} + const void* Literal::InternalData() const { return const_cast( const_cast(this)->MutableInternalData()); @@ -630,7 +653,6 @@ void* Literal::MutableInternalData() { // created by the accessor 
functions. switch (shape().element_type()) { case PRED: - return reinterpret_cast(preds_.data()); case U8: return reinterpret_cast(u8s_.data()); case S32: @@ -698,8 +720,6 @@ tensorflow::Status Literal::ValidateLiteral() const { int64 actual = -1; switch (shape().element_type()) { case PRED: - actual = preds_size(); - break; case U8: actual = u8s_size(); break; @@ -754,10 +774,30 @@ void Literal::EachCellAsString( } namespace { +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + auto result_literal = MakeUnique(); + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = src_literal.shape(); + result_shape->set_element_type( + primitive_util::NativeToPrimitiveType()); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + tensorflow::gtl::ArraySlice src_data = + src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); + } + return result_literal; +} + template std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< + return ConvertBetweenNativeTypes< typename primitive_util::PrimitiveTypeToNative::type, typename primitive_util::PrimitiveTypeToNative< primitive_dest_type>::type>(src_literal); @@ -782,19 +822,20 @@ StatusOr> ConvertIfDestTypeMatches( #undef CONVERT_IF_TYPES_MATCH // Other types are not yet supported. 
default: - return tensorflow::errors::InvalidArgument( - "Unimplemented: ConvertIfDestTypeMatches for type " + - PrimitiveType_Name(src_literal.shape().element_type())); + return InvalidArgument( + "Unimplemented: Convert from type %s to type %s", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } -} +} // namespace -StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { +StatusOr> Literal::Convert( + PrimitiveType primitive_dest_type) const { + switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) CONVERT_IF_DEST_TYPE_MATCHES(S32) @@ -807,9 +848,9 @@ StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. 
default: - return tensorflow::errors::InvalidArgument( - "Unimplemented: ConvertIfSrcTypeMatches for type " + - PrimitiveType_Name(src_literal.shape().element_type())); + return InvalidArgument("Unimplemented: Convert from type %s to type %s", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } @@ -884,26 +925,22 @@ bool Literal::Equal(const Literal& literal2) const { template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { auto values = mutable_preds(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(values->data()), values->size()); } template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // C++11 standard, basic_string 21.4.1.5, values should be stored - // contiguously. From C++17 a mutable data() member will be provided. auto values = mutable_u8s(); return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(&(*values)[0]), values->size()); + reinterpret_cast(values->data()), values->size()); } template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // C++11 standard, basic_string 21.4.1.5, values should be stored - // contiguously. From C++17 a mutable data() member will be provided. auto values = mutable_u8s(); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(&(*values)[0]), values->size()); + return tensorflow::gtl::MutableArraySlice(values->data(), + values->size()); } template <> @@ -965,19 +1002,18 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // C++11 standard, basic_string 21.4.1.5, values should be stored - // contiguously. From C++17 a mutable data() member will be provided. // TODO - there is an endianess problem here. 
fix it, or wait for uint16 // support in protobuf auto values = mutable_f16s(); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(&(*values)[0]), values->size() / sizeof(half)); + return tensorflow::gtl::MutableArraySlice(values->data(), + values->size()); } template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { CHECK_EQ(shape().element_type(), PRED); - return tensorflow::gtl::ArraySlice(preds().data(), preds().size()); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(preds().data()), preds().size()); } template <> @@ -1027,9 +1063,8 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { CHECK_EQ(shape().element_type(), F16); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(f16s().data()), - f16s().size() / sizeof(half)); + return tensorflow::gtl::ArraySlice(f16s().data(), + f16s().size() / sizeof(half)); } template @@ -1192,21 +1227,13 @@ static void CopyToRepeatedField(RepeatedFieldT* dest, *dest = RepeatedFieldT(src.begin(), src.end()); } -template -static void CopyToRepeatedBoolField(RepeatedFieldT* dest, - const BoolVector& src) { - *dest = RepeatedFieldT(src.begin(), src.end()); -} - LiteralProto Literal::ToProto() const { LiteralProto proto; proto.Clear(); *proto.mutable_shape() = shape(); switch (shape().element_type()) { case PRED: - if (preds().begin()) { - CopyToRepeatedBoolField(proto.mutable_preds(), preds()); - } + CopyToRepeatedField(proto.mutable_preds(), preds()); break; case U8: *proto.mutable_u8s() = u8s_string(); @@ -1260,8 +1287,7 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { *mutable_shape() = literal_proto.shape(); switch (shape().element_type()) { case PRED: - *mutable_preds() = BoolVector(literal_proto.preds().begin(), - literal_proto.preds().end()); + CopyFromRepeatedField(mutable_preds(), literal_proto.preds()); break; case U8: set_u8s(literal_proto.u8s()); diff --git 
a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 42c8b61acec8f4dc661111affc17773b1aa71583..125c268573becad622d880aab9a7f3dd18ab68df 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -49,94 +49,6 @@ limitations under the License. namespace xla { -// This class is a simple vector of boolean values. It's used to workaround some -// implementations of std::vector that use a bitset which does not have -// the semantics expected by Literal::preds(). -class BoolVector { - public: - typedef bool* iterator; - typedef const bool* const_iterator; - - BoolVector() : bits_(nullptr), size_(0), capacity_(0) {} - - BoolVector(const_iterator other_begin, const_iterator other_end) - : bits_(nullptr), size_(0), capacity_(0) { - if (other_begin && other_end) { - resize(other_end - other_begin); - memcpy(begin(), other_begin, size()); - } - } - - BoolVector(const BoolVector& other) { CopyFrom(other); } - - BoolVector& operator=(const BoolVector& other) { - CopyFrom(other); - return *this; - } - - void push_back(const bool& value) { - resize(size_ + 1); - bits_[size_ - 1] = value; - } - - bool* data() const { return bits_.get(); } - - size_t size() const { return size_; } - - size_t capacity() const { return capacity_; } - - void resize(size_t new_size, bool val = false) { - if (new_size == 0) { - bits_.reset(nullptr); - size_ = 0; - capacity_ = 0; - } else { - size_t old_size = size(); - if (new_size > old_size) { - grow(new_size); - } - if (old_size < new_size) { - memset(&bits_[old_size], val, new_size - old_size); - } - size_ = new_size; - } - } - - void clear() { - bits_.reset(nullptr); - size_ = 0; - capacity_ = 0; - } - - iterator begin() { return &bits_[0]; } - iterator end() { return &bits_[size()]; } - const_iterator begin() const { return &bits_[0]; } - const_iterator end() const { return &bits_[size()]; } - - private: - void grow(size_t n) { - if (capacity_ < n) { - capacity_ = 2 * n; - 
bool* new_bits = new bool[capacity_](); - if (size_ > 0) { - memcpy(new_bits, bits_.get(), size_); - } - bits_.reset(new_bits); - } - } - - void CopyFrom(const BoolVector& other) { - bits_ = MakeUnique(other.capacity()); - memcpy(begin(), other.begin(), other.size()); - size_ = other.size(); - capacity_ = other.capacity(); - } - - std::unique_ptr bits_; - size_t size_; - size_t capacity_; -}; - // Utility class for dealing with XLA literal values. Most methods are // templated by native (host) type which corresponds to a unique XLA // PrimitiveType. See ComputationBuilder for details. Not all primitive types @@ -147,10 +59,12 @@ class Literal { Literal() {} Literal(const Literal& other) = default; + Literal(Literal&&) = default; explicit Literal(const LiteralProto& other) { CopyFromProto(other); } Literal& operator=(const Literal& other) = default; + Literal& operator=(Literal&&) = default; LiteralProto ToProto() const; @@ -165,7 +79,6 @@ class Literal { void Clear() { shape_.Clear(); - preds_.clear(); u8s_.clear(); s32s_.clear(); s64s_.clear(); @@ -177,9 +90,17 @@ class Literal { tuple_literals_.clear(); } - int preds_size() const { return preds().size(); } - const BoolVector& preds() const { return preds_; } - BoolVector* mutable_preds() { return &preds_; } + int preds_size() const { return u8s().size(); } + const std::vector& preds() const { + static_assert(sizeof(uint8) == sizeof(bool), + "The uint8 and bool types should be the same size"); + return u8s_; + } + std::vector* mutable_preds() { + static_assert(sizeof(uint8) == sizeof(bool), + "The uint8 and bool types should be the same size"); + return &u8s_; + } int s32s_size() const { return s32s().size(); } int32 s32s(int i) const { return s32s_[i]; } @@ -251,7 +172,7 @@ class Literal { *other = temp; } - // CreatesCreate new literal of a given rank. To minimize ambiguity (for users + // Creates a new literal of a given rank. 
To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the // native type. For example: // @@ -362,10 +283,10 @@ class Literal { template std::unique_ptr Replicate(int64 times) const; - // Creates a literal by converting each element in this literal to a new - // type. - template - std::unique_ptr Convert() const; + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; // Creates a literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -444,10 +365,21 @@ class Literal { template void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - // Retrieves the mutable array slice interface which can be used to manipulate - // pre-allocated literal values. + // Returns a (Mutable)ArraySlice view of the array for this literal for the + // given NativeT (e.g., float). These functions map native type to XLA + // PrimitiveType via template specialization. The unspecialized forms below + // aborts to handle the error case where the given native type does not map to + // an XLA primitive type. template - tensorflow::gtl::MutableArraySlice GetMutableArraySlice(); + tensorflow::gtl::ArraySlice GetArraySlice() const { + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + } + template + tensorflow::gtl::MutableArraySlice GetMutableArraySlice() { + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + } // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -466,6 +398,16 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. 
+ // + // std::vector> elements = ...; + // auto result = Literal::MakeTupleOwned(std::move(elements)); + // + // This would have been declared as an overload, but there is ambiguity + // in invocation between the above signature and this one. + static std::unique_ptr MakeTupleOwned( + std::vector> elements); + // Validates that the data payload of the literal matches the literal shape; // if it does not, an appropriate status is returned. tensorflow::Status ValidateLiteral() const; @@ -588,17 +530,6 @@ class Literal { bool IsZero(tensorflow::gtl::ArraySlice indices) const; private: - // Returns an ArraySlice view of the array for this literal for the given - // NativeT (e.g., float). These functions map native type to XLA PrimitiveType - // via template specialization. The unspecialized forms below aborts to handle - // the error case where the given native type does not map to an XLA primitive - // type. - template - tensorflow::gtl::ArraySlice GetArraySlice() const { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } - // Copy from a LiteralProto instance. void CopyFromProto(const LiteralProto& literal_proto); @@ -634,7 +565,6 @@ class Literal { }; Shape shape_; - BoolVector preds_; std::vector u8s_; std::vector s32s_; std::vector s64s_; @@ -646,544 +576,6 @@ class Literal { std::vector tuple_literals_; }; -// Utility class for dealing with XLA literal values. Most methods are -// templated by native (host) type which corresponds to a unique XLA -// PrimitiveType. See ComputationBuilder for details. Not all primitive types -// defined in xla_data.proto have a corresponding native type or even have a -// storage location in the Literal proto yet (for example, primitive type F16). -// -// TODO(dnovillo) - All functions in this class simply redirect to the -// corresponding function in class Literal. Remove this class after converting -// all user code to use Literal directly. 
-class LiteralUtil { - public: - // Creates new literal of a given rank. To minimize ambiguity (for users and - // the compiler) these CreateR[0-2] methods should explicitly specify the - // native type. For example: - // - // CreateR1({1.0, 42.0}); - // CreateR2({{1, 2}, {3, 4}}); - // - // The variants not ending with WithLayout use the default XLA layout for the - // literal's linear representation in memory. - template - static std::unique_ptr CreateR0(NativeT value) { - return Literal::CreateR0(value); - } - - template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values) { - return Literal::CreateR1(values); - } - - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values) { - return Literal::CreateR1(values); - } - - template - static std::unique_ptr CreateR2( - std::initializer_list> values) { - return Literal::CreateR2(values); - } - - template - static std::unique_ptr CreateR2WithLayout( - std::initializer_list> values, - const Layout& layout) { - return Literal::CreateR2WithLayout(values, layout); - } - - template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values) { - return Literal::CreateR3(values); - } - - template - static std::unique_ptr CreateR3WithLayout( - std::initializer_list< - std::initializer_list>> - values, - const Layout& layout) { - return Literal::CreateR3WithLayout(values, layout); - } - - template - static std::unique_ptr CreateR4( - std::initializer_list>>> - values) { - return Literal::CreateR4(values); - } - - template - static std::unique_ptr CreateR4WithLayout( - std::initializer_list>>> - values, - const Layout& layout) { - return Literal::CreateR4WithLayout(values, layout); - } - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). 
- static std::unique_ptr CreateFromShape(const Shape& shape) { - return Literal::CreateFromShape(shape); - } - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { - return Literal::CreateFromDimensions(primitive_type, dimensions); - } - - // Copies the values from src_literal, starting at src_base shape indexes, - // to dest_literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // - // The src_literal and dest_literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - static Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - return dest_literal->Copy(src_literal, src_base, dest_base, copy_size); - } - - // Creates a new value that has the equivalent value as literal, but conforms - // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension - // layout and the value in the cell at any given logical index (i0, i1) will - // be the same. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. 
- static std::unique_ptr Relayout(const Literal& literal, - const Layout& new_layout) { - return literal.Relayout(new_layout); - } - - // Reshapes literal 'input' to have 'shape'. Both the original shape and - // 'shape' must contain the same number of elements. The implementation - // currently only supports monotonic dim0-major layouts. - static StatusOr> Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice shape) { - return input.Reshape(shape); - } - - // Creates a new literal by reordering the dimensions of the original literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - static std::unique_ptr Transpose( - const Literal& literal, tensorflow::gtl::ArraySlice permutation) { - return literal.Transpose(permutation); - } - - // Creates a sub-array from the given literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - static std::unique_ptr Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { - return literal.Slice(start_indices, limit_indices); - } - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input - // literal replicated four times. 
- template - static std::unique_ptr Replicate(const Literal& input, int64 times) { - return input.Replicate(times); - } - - // Creates a literal by converting each element in an original literal to a - // new type. - template - static std::unique_ptr Convert(const Literal& literal) { - return literal.Convert(); - } - - // Convert a literal to another primitive type, but only if the literal - // type is connvertable into the destination type - static StatusOr> ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type); - - // Creates a literal value zero of the given primitive type. - static Literal Zero(PrimitiveType primitive_type) { - return Literal::Zero(primitive_type); - } - - // Creates a literal value one of the given primitive type. - static Literal One(PrimitiveType primitive_type) { - return Literal::One(primitive_type); - } - - // Creates a literal value containing the minimum value of the given - // primitive type. For floating-point types, returns -inf. - static Literal MinValue(PrimitiveType primitive_type) { - return Literal::MinValue(primitive_type); - } - - // Creates a literal value containing the maximum value of the given - // primitive type. For floating-point types, returns inf. - static Literal MaxValue(PrimitiveType primitive_type) { - return Literal::MaxValue(primitive_type); - } - - // Creates a literal of the given shape where each element is `value`. - template - static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { - return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value); - } - - // Creates a new literal from an array. The variants not ending with - // WithLayout use the default XLA layout for the literal's linear - // representation in memory. 
- template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values) { - return Literal::CreateR2FromArray2D(values); - } - - template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return Literal::CreateR2FromArray2DWithLayout(values, layout); - } - - template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values) { - return Literal::CreateR3FromArray3D(values); - } - - template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return Literal::CreateR3FromArray3DWithLayout(values, layout); - } - - template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values) { - return Literal::CreateR4FromArray4D(values); - } - - template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return Literal::CreateR4FromArray4DWithLayout(values, layout); - } - - // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(tensorflow::StringPiece value) { - return Literal::CreateR1U8(value); - } - - // Creates a linspace-populated literal with the given number of rows and - // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols) { - return Literal::CreateR2F32Linspace(from, to, rows, cols); - } - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z dimension given by "projection". - template - static std::unique_ptr CreateR3Projected( - std::initializer_list> values, - int64 projection) { - return Literal::CreateR3Projected(values, projection); - } - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. 
- template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z) { - return Literal::CreateR4Projected(values, projection_p, projection_z); - } - - // Clones literal into an owned unique_ptr version. - static std::unique_ptr CloneToUnique(const Literal& literal) { - return literal.CloneToUnique(); - } - - // Returns the linear index of the given index within the literal's - // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.LinearIndex(multi_index); - } - - // Gets or sets an element in the literal at the given index. The index is - // CHECKed against the dimension sizes. - template - static NativeT Get(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.Get(multi_index); - } - - template - static void Set(Literal* literal, - tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - literal->Set(multi_index, value); - } - - // Retrieves the mutable array slice interface which can be used to manipulate - // pre-allocated literal values. - template - static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( - Literal* literal) { - return literal->GetMutableArraySlice(); - } - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - static NativeT GetFirstElement(const Literal& literal) { - return literal.GetFirstElement(); - } - - // As Get(), but determines the correct type and converts the value - // into text. - static string GetAsString(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.GetAsString(multi_index); - } - - // Returns an identity matrix (rank 2) with the given row and column count. - template - static std::unique_ptr MakeIdentityR2(int64 size) { - return Literal::MakeIdentityR2(size); - } - - // Returns a tuple literal composed of given literals. 
- static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements) { - return Literal::MakeTuple(elements); - } - - // Validates that the data payload of the literal matches the literal shape; - // if it does not, an appropriate status is returned. - static tensorflow::Status ValidateLiteral(const Literal& literal) { - return literal.ValidateLiteral(); - } - - // Returns a string representation of the literal value. - static string ToString(const Literal& literal) { return literal.ToString(); } - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - static void EachCellAsString( - const Literal& literal, - const std::function indices, - const string& value)>& per_cell) { - literal.EachCellAsString(per_cell); - } - - template - static void EachCell( - const Literal& literal, - std::function indices, - NativeT value)> - per_cell) { - literal.EachCell(per_cell); - } - - // Templated methods which populate the given repeated field in the Literal - // proto with the given value(s). The Shape field of the Literal proto is set - // to match the array dimensions and type. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // PopulateR2FromArray2D(values, literal); - // - // // Populate with int32s. 
- // PopulateR2({{1, 2}, {3, 4}}, literal); - // - template - static void PopulateR0(NativeT values, Literal* literal) { - literal->PopulateR0(values); - } - - template - static void PopulateR1(tensorflow::gtl::ArraySlice values, - Literal* literal) { - literal->PopulateR1(values); - } - - static void PopulateR1(const tensorflow::core::Bitmap& values, - Literal* literal) { - literal->PopulateR1(values); - } - - template - static void PopulateR2( - std::initializer_list> values, - Literal* literal) { - literal->PopulateR2(values); - } - - template - static void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout, Literal* literal) { - literal->PopulateR2WithLayout(values, layout); - } - - template - static void PopulateR2FromArray2D(const Array2D& values, - Literal* literal) { - literal->PopulateR2FromArray2D(values); - } - - template - static void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); - } - - template - static void PopulateR3FromArray3D(const Array3D& values, - Literal* literal) { - literal->PopulateR3FromArray3D(values); - } - - template - static void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - } - - template - static void PopulateR4FromArray4D(const Array4D& values, - Literal* literal) { - literal->PopulateR4FromArray4D(values); - } - - template - static void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - } - - // Populates literal values by calling the generator function for every cell - // in the literal object. 
- template - static Status Populate( - Literal* literal, - const std::function indexes)>& - generator) { - return literal->Populate(generator); - } - - // Creates a Literal of the given dimensions with all elements set to the - // given value. - template - static void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - return literal->PopulateWithValue(value, dimensions); - } - - // Returns a pointer to the underlying vector containing the array data. Use - // with care. - static const void* InternalData(const Literal& literal) { - return literal.InternalData(); - } - - static void* MutableInternalData(Literal* literal) { - return literal->MutableInternalData(); - } - - // Allocates space in the underlying vector of the literal sufficient to hold - // num_elements of the literal's primitive type. Values in the vector are set - // to zero. num_elements must equal the number of elements in the literals - // shape. - static void Reserve(int64 num_elements, Literal* literal) { - literal->Reserve(num_elements); - } - - // Allocates space in the underlying vector of the literal sufficient to hold - // num_elements of the literal's primitive type and sets each element in the - // literal to the given value. num_elements must equal the number of elements - // in the literals shape. - template - static void Resize(int64 num_elements, NativeT value, Literal* literal) { - literal->Resize(num_elements, value); - } - - // Returns true if the two given literals have the same shape and - // values. Layout is not considered in the comparison. - static bool Equal(const Literal& literal1, const Literal& literal2) { - return literal1.Equal(literal2); - } - - // Returns whether every element in the given literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) 
and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in literal's type, returns false. Values of 1/0 are - // considered equal to true/false; other values are not considered equal to - // true. - static bool IsAll(const Literal& literal, int8 value) { - return literal.IsAll(value); - } - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. - static bool IsAllFloat(const Literal& literal, float value) { - return literal.IsAllFloat(value); - } - - // Returns whether the literal is zero at the specified index. The literal - // must be an array. - static bool IsZero(const Literal& literal, - tensorflow::gtl::ArraySlice indices) { - return literal.IsZero(indices); - } - - TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); -}; - // Declarations of template specializations for GetArraySlice and // GetMutableArraySlice. The specializations map native type to XLA primitive // type. 
@@ -1759,27 +1151,6 @@ void Literal::PopulateWithValue(NativeT value, Resize(ShapeUtil::ElementsIn(shape()), value); } -template -std::unique_ptr Literal::Convert() const { - const Shape& this_shape = shape(); - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = this_shape; - result_shape->set_element_type( - primitive_util::NativeToPrimitiveType()); - result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); - tensorflow::gtl::ArraySlice src_data = - GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(this_shape); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); - } - return result_literal; -} - template /* static */ std::unique_ptr Literal::CreateFullWithMonotonicDim0MajorLayout( diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 8d4a75d7affebd3ee39702cb1226ee52aff09691..b50e741b8ad55173d932231836abd5996cf1a068 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -72,11 +72,11 @@ class LiteralUtilTest : public ::testing::Test { layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); literal_r4_2x2x3x3_dim0major_ = - LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0major_); + Literal::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); literal_r4_2x2x3x3_dim0minor_ = - LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0minor_); + Literal::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); } Layout layout_r2_dim0major_; @@ -90,43 +90,42 @@ class LiteralUtilTest : public ::testing::Test { }; TEST_F(LiteralUtilTest, LiteralScalarToString) { - auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", LiteralUtil::ToString(*true_lit)); + auto true_lit = Literal::CreateR0(true); + 
ASSERT_EQ("true", true_lit->ToString()); - auto false_lit = LiteralUtil::CreateR0(false); - ASSERT_EQ("false", LiteralUtil::ToString(*false_lit)); + auto false_lit = Literal::CreateR0(false); + ASSERT_EQ("false", false_lit->ToString()); - auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", LiteralUtil::ToString(*u32_lit)); + auto u32_lit = Literal::CreateR0(42); + ASSERT_EQ("42", u32_lit->ToString()); - auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", LiteralUtil::ToString(*s32_lit)); + auto s32_lit = Literal::CreateR0(-999); + ASSERT_EQ("-999", s32_lit->ToString()); - auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + auto f32_lit = Literal::CreateR0(3.14f); + ASSERT_EQ("3.14", f32_lit->ToString()); - auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); + auto f16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", f16_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { - auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", LiteralUtil::ToString(*pred_vec)); + auto pred_vec = Literal::CreateR1({true, false, true}); + ASSERT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { - const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 }, })"; - ASSERT_EQ(expected, LiteralUtil::ToString(*literal)); + ASSERT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, R3ToString) { - const auto literal = - LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { { { 1 }, { 2 } }, @@ -135,13 +134,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, 
LiteralUtil::ToString(*literal)); + ASSERT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, TupleToString) { - auto scalar = LiteralUtil::CreateR0(1.0); - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -149,7 +148,7 @@ f32[2,2] { { 3, 4 }, }, ))"; - ASSERT_EQ(expected, LiteralUtil::ToString(*tuple)); + ASSERT_EQ(expected, tuple->ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -164,9 +163,9 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { }); // clang-format on - auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); + auto literal = Literal::CreateR3FromArray3D(array_3d); EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = LiteralUtil::ToString(*literal); + string result = literal->ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -180,14 +179,14 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off - auto literal = LiteralUtil::CreateR4Projected({ + auto literal = Literal::CreateR4Projected({ {1, 2}, {1001, 1002}, {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = LiteralUtil::ToString(*literal); + string result = literal->ToString(); const string expected = R"(f32[1,2,3,2] { { // i0=0 { // i1=0 @@ -208,7 +207,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = 
LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_); + string result = literal_r4_2x2x3x3_dim0major_->ToString(); const string expected = R"(f32[2,2,3,3] { { // i0=0 { // i1=0 @@ -240,14 +239,13 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format off - auto literal = LiteralUtil::CreateR2({ + auto literal = Literal::CreateR2({ {3.1f, 4.2f}, {9.3f, 12.4f}, }); // clang-format on std::vector> seen; - LiteralUtil::EachCellAsString( - *literal, + literal->EachCellAsString( [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -259,176 +257,171 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { } TEST_F(LiteralUtilTest, ScalarEquality) { - // Test LiteralUtil::Equal with scalars. - auto f32_42 = LiteralUtil::CreateR0(42.0); - auto f32_42_clone = LiteralUtil::CreateR0(42.0); + // Test Literal::Equal with scalars. + auto f32_42 = Literal::CreateR0(42.0); + auto f32_42_clone = Literal::CreateR0(42.0); - EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42)); - EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42_clone)); + EXPECT_TRUE(f32_42->Equal(*f32_42)); + EXPECT_TRUE(f32_42->Equal(*f32_42_clone)); - auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f32_123)); + auto f32_123 = Literal::CreateR0(123.0); + EXPECT_FALSE(f32_42->Equal(*f32_123)); - auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f64_42)); + auto f64_42 = Literal::CreateR0(42.0); + EXPECT_FALSE(f32_42->Equal(*f64_42)); } TEST_F(LiteralUtilTest, NonScalarEquality) { - // Test LiteralUtil::Equal with nonscalars. 
- auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_different = - LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); - auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); - auto scalar = LiteralUtil::CreateR0(1.0); - - EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix)); - EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix_clone)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *matrix_different)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *vector_literal)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *scalar)); + // Test Literal::Equal with nonscalars. + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = Literal::CreateR0(1.0); + + EXPECT_TRUE(matrix->Equal(*matrix)); + EXPECT_TRUE(matrix->Equal(*matrix_clone)); + EXPECT_FALSE(matrix->Equal(*matrix_different)); + EXPECT_FALSE(matrix->Equal(*vector_literal)); + EXPECT_FALSE(matrix->Equal(*scalar)); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { - // Test LiteralUtil::Equal with literals which have different layouts. + // Test Literal::Equal with literals which have different layouts. 
auto colmajor = MakeUnique(); *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - LiteralUtil::Reserve(4, colmajor.get()); - LiteralUtil::Set(colmajor.get(), {0, 0}, 1.0); - LiteralUtil::Set(colmajor.get(), {0, 1}, 2.0); - LiteralUtil::Set(colmajor.get(), {1, 0}, 3.0); - LiteralUtil::Set(colmajor.get(), {1, 1}, 4.0); + colmajor->Reserve(4); + colmajor->Set({0, 0}, 1.0); + colmajor->Set({0, 1}, 2.0); + colmajor->Set({1, 0}, 3.0); + colmajor->Set({1, 1}, 4.0); auto rowmajor = MakeUnique(); *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - LiteralUtil::Reserve(4, rowmajor.get()); - LiteralUtil::Set(rowmajor.get(), {0, 0}, 1.0); - LiteralUtil::Set(rowmajor.get(), {0, 1}, 2.0); - LiteralUtil::Set(rowmajor.get(), {1, 0}, 3.0); - LiteralUtil::Set(rowmajor.get(), {1, 1}, 4.0); + rowmajor->Reserve(4); + rowmajor->Set({0, 0}, 1.0); + rowmajor->Set({0, 1}, 2.0); + rowmajor->Set({1, 0}, 3.0); + rowmajor->Set({1, 1}, 4.0); - EXPECT_TRUE(LiteralUtil::Equal(*rowmajor, *colmajor)); + EXPECT_TRUE(rowmajor->Equal(*colmajor)); } TEST_F(LiteralUtilTest, TupleEquality) { - // Test LiteralUtil::Equal with tuples. - auto scalar = LiteralUtil::CreateR0(1.0); - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + // Test Literal::Equal with tuples. + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. 
- auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_TRUE(LiteralUtil::Equal(*tuple1, *tuple2)); + auto scalar_clone = Literal::CreateR0(1.0); + auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); + EXPECT_TRUE(tuple1->Equal(*tuple2)); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *reversed_tuple)); + auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); + EXPECT_FALSE(tuple1->Equal(*reversed_tuple)); // Tuple with different value. - auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *different_tuple)); + auto scalar_42 = Literal::CreateR0(42.0); + auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); + EXPECT_FALSE(tuple1->Equal(*different_tuple)); } TEST_F(LiteralUtilTest, IsAllTuple) { - auto element1 = LiteralUtil::CreateR0(0.0); - auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto element1 = Literal::CreateR0(0.0); + auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); // Tuples should always return false for IsAll. - EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 0)); - EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 1)); + EXPECT_FALSE(tuple->IsAll(0)); + EXPECT_FALSE(tuple->IsAll(1)); +} + +// Verifies that CreateFromShape works for tuples. 
+TEST_F(LiteralUtilTest, CreateFromShapeTuple) { + auto scalar = Literal::CreateR0(0.0); + auto matrix = Literal::CreateR2({{0, 0}, {0, 0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + + auto x = Literal::CreateFromShape(tuple->shape()); + EXPECT_TRUE(tuple->Equal(*x)); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 0)); - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 1)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 1)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 2)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 0)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 2)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), -1)); + EXPECT_TRUE(Literal::CreateR0(false)->IsAll(0)); + EXPECT_TRUE(Literal::CreateR0(true)->IsAll(1)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAll(1)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAll(2)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(0)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(2)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. 
auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR0(255), int8_min)); + EXPECT_FALSE(Literal::CreateR0(255)->IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0), 42)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0001), 42)); + EXPECT_TRUE(Literal::CreateR0(42.0)->IsAll(42)); + EXPECT_FALSE(Literal::CreateR0(42.0001)->IsAll(42)); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR1({100, 100, 100}), 100)); - EXPECT_FALSE(LiteralUtil::IsAll( - *LiteralUtil::CreateR1({100, 100, 100.001}), 100)); + EXPECT_TRUE(Literal::CreateR1({100, 100, 100})->IsAll(100)); + EXPECT_FALSE(Literal::CreateR1({100, 100, 100.001})->IsAll(100)); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 8}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 9}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + EXPECT_TRUE(Literal::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + EXPECT_TRUE(Literal::CreateR2({{h8}, {h8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); auto uint64_max = std::numeric_limits::max(); - EXPECT_FALSE(LiteralUtil::IsAll( - *LiteralUtil::CreateR2( - {{uint64_max, uint64_max}, {uint64_max, uint64_max}}), - -1)); + EXPECT_FALSE(Literal::CreateR2( + {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) + ->IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not 
floating-point. - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(false), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(.5), .5)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.5)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + + EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.49)); - EXPECT_FALSE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}), .5)); - - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(.5), .5)); + Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); EXPECT_TRUE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.5)); + Literal::CreateR2({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); + + EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.49)); - EXPECT_FALSE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}), 0)); + 
Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsZero) { - auto scalar_zero = LiteralUtil::CreateR0(0.0f); - auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(LiteralUtil::IsZero(*scalar_zero, {})); - EXPECT_FALSE(LiteralUtil::IsZero(*scalar_one, {})); - - auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(LiteralUtil::IsZero(*array, {0, 1})); - EXPECT_TRUE(LiteralUtil::IsZero(*array, {0, 2})); - EXPECT_TRUE(LiteralUtil::IsZero(*array, {1, 1})); - EXPECT_FALSE(LiteralUtil::IsZero(*array, {1, 2})); + auto scalar_zero = Literal::CreateR0(0.0f); + auto scalar_one = Literal::CreateR0(1.0f); + EXPECT_TRUE(scalar_zero->IsZero({})); + EXPECT_FALSE(scalar_one->IsZero({})); + + auto array = Literal::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + EXPECT_FALSE(array->IsZero({0, 1})); + EXPECT_TRUE(array->IsZero({0, 2})); + EXPECT_TRUE(array->IsZero({1, 1})); + EXPECT_FALSE(array->IsZero({1, 2})); } template @@ -440,127 +433,122 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. 
TypeParam half = TypeParam(1) / TypeParam(2); - auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); + auto data = Literal::CreateR2({{half, 2}, {3, 4}}); const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = LiteralUtil::Relayout(*data, layout01); + auto data01 = data->Relayout(layout01); EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_TRUE(LiteralUtil::Equal(*data, *data01)); + EXPECT_TRUE(data->Equal(*data01)); - auto data10 = LiteralUtil::Relayout(*data, layout10); + auto data10 = data->Relayout(layout10); EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_TRUE(LiteralUtil::Equal(*data, *data10)); + EXPECT_TRUE(data->Equal(*data10)); } TEST_F(LiteralUtilTest, ReshapeR0) { - auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = - LiteralUtil::Reshape(*original, /*shape=*/{}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); + auto original = Literal::CreateR0(1.7f); + auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + EXPECT_TRUE(original->Equal(*reshape)); } TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // F32[1x3x4x2] - auto expected = LiteralUtil::CreateR3WithLayout({ + auto expected = Literal::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); + 
EXPECT_TRUE(expected->Equal(*reshape)); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0minor_); // F32[1x3x4x2] - auto expected = LiteralUtil::CreateR3WithLayout({ + auto expected = Literal::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); + EXPECT_TRUE(expected->Equal(*reshape)); } TEST_F(LiteralUtilTest, TransposeR0) { - auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); - EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); + auto original = Literal::CreateR0(1.7f); + auto reshape = original->Transpose(/*permutation=*/{}); + EXPECT_TRUE(original->Equal(*reshape)); } TEST_F(LiteralUtilTest, TransposeR4) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4({{ + auto original = Literal::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = - LiteralUtil::Transpose(*original, /*permutation=*/{2, 3, 0, 1}); - - LiteralUtil::EachCell( - *reshape, [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, - LiteralUtil::Get(*original, {indices[2], indices[3], - indices[0], indices[1]})); + auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + + reshape->EachCell( + [&](tensorflow::gtl::ArraySlice indices, float value) { + 
EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. - auto dim0minor_relaid_to_dim0major = LiteralUtil::Relayout( - *literal_r4_2x2x3x3_dim0minor_, layout_r4_dim0major_); - EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0major_, - *dim0minor_relaid_to_dim0major)); - - auto dim0major_relaid_to_dim0minor = LiteralUtil::Relayout( - *literal_r4_2x2x3x3_dim0major_, layout_r4_dim0minor_); - EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0minor_, - *dim0major_relaid_to_dim0minor)); + auto dim0minor_relaid_to_dim0major = + literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); + EXPECT_TRUE( + literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major)); + + auto dim0major_relaid_to_dim0minor = + literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); + EXPECT_TRUE( + literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor)); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( - {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); + auto mat_dim0minor = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, + layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->s32s_size(), 6); EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = - LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_); + auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). 
- auto mat_dim0major = LiteralUtil::CreateR2WithLayout( - {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); + auto mat_dim0major = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, + layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->s32s_size(), 6); EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. - auto relaid_mat_to_dim0minor = - LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_); + auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -578,8 +566,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { {10, 11, 12}, }, }); // clang-format on - auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( - arr3d, layout_r3_dim0minor_); + auto lit_dim0minor = + Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); EXPECT_EQ(lit_dim0minor->s32s_size(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; @@ -587,122 +575,120 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = - LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; EXPECT_THAT(relaid_lit_to_dim0major->s32s(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). 
- auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( - arr3d, layout_r3_dim0major_); + auto lit_dim0major = + Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->s32s_size(), 12); EXPECT_THAT(lit_dim0major->s32s(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. - auto relaid_lit_to_dim0minor = - LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_); + auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { - auto input = LiteralUtil::CreateR0(1); - auto result = LiteralUtil::Slice(*input, {}, {}); - EXPECT_TRUE(LiteralUtil::Equal(*input, *result)); + auto input = Literal::CreateR0(1); + auto result = input->Slice({}, {}); + EXPECT_TRUE(input->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR1F32) { - auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = LiteralUtil::Slice(*input, {3}, {4}); - auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); + auto input = Literal::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto result = input->Slice({3}, {4}); + auto expected = Literal::CreateR1({4.0}); + EXPECT_TRUE(expected->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR2U32) { - auto input_3x4 = LiteralUtil::CreateR2( - {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = LiteralUtil::Slice(*input_3x4, {0, 2}, {2, 4}); - auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); + auto input_3x4 = + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto expected = Literal::CreateR2({{3, 4}, {7, 8}}); + EXPECT_TRUE(expected->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR3U32Full) { - auto 
input_2x3x2 = LiteralUtil::CreateR3( + auto input_2x3x2 = Literal::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = LiteralUtil::Slice(*input_2x3x2, {0, 0, 0}, {2, 3, 2}); - EXPECT_TRUE(LiteralUtil::Equal(*input_2x3x2, *result)); + auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_TRUE(input_2x3x2->Equal(*result)); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output; - LiteralUtil::PopulateR1({77}, &output); - auto expected = LiteralUtil::CreateR1({77}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateR1({77}); + auto expected = Literal::CreateR1({77}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateR2U64) { Literal output; - LiteralUtil::PopulateR1({{77, 88}}, &output); - auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateR1({{77, 88}}); + auto expected = Literal::CreateR1({{77, 88}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; - LiteralUtil::PopulateWithValue(2.5f, {}, &output); - auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(2.5f, {}); + auto expected = Literal::CreateR0(2.5f); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output; - LiteralUtil::PopulateWithValue(-7, {3}, &output); - auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(-7, {3}); + auto expected = Literal::CreateR1({-7, -7, -7}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output; - LiteralUtil::PopulateWithValue(42, {2, 2}, &output); - auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(42, {2, 2}); + 
auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); - LiteralUtil::PopulateWithValue(h, {}, &output); - auto expected = LiteralUtil::CreateR0(h); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { Literal output; half h(0.5f); - LiteralUtil::PopulateWithValue(h, {3}, &output); - auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { Literal output; half h(2.0f); - LiteralUtil::PopulateWithValue(h, {2, 2}, &output); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, ReplicateR2U32) { - auto input = LiteralUtil::CreateR2( - {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = LiteralUtil::Replicate(*input, 3); - auto expected = LiteralUtil::CreateR3( + auto input = + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto output = input->Replicate(3); + auto expected = Literal::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); + EXPECT_TRUE(output->Equal(*expected)); } TEST_F(LiteralUtilTest, Copy) { @@ -712,13 +698,13 @@ TEST_F(LiteralUtilTest, Copy) { for (const auto& layout : layouts) { Shape shape = ShapeUtil::MakeShapeWithLayout( 
primitive_util::NativeToPrimitiveType(), dimensions, layout); - auto blank = LiteralUtil::CreateFromShape(shape); - auto source = LiteralUtil::CreateFromShape(shape); + auto blank = Literal::CreateFromShape(shape); + auto source = Literal::CreateFromShape(shape); const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; auto init_proc = [&](const std::vector& indexes) { - LiteralUtil::Set(source.get(), indexes, ++seqnr); + source->Set(indexes, ++seqnr); return true; }; @@ -729,8 +715,7 @@ TEST_F(LiteralUtilTest, Copy) { const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, - copy_size)); + TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; @@ -741,9 +726,8 @@ TEST_F(LiteralUtilTest, Copy) { std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = LiteralUtil::Get(*blank, blank_indexes); - matched = (bval != 0 && - bval == LiteralUtil::Get(*source, source_indexes)); + auto bval = blank->Get(blank_indexes); + matched = (bval != 0 && bval == source->Get(source_indexes)); return matched; }; ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, @@ -753,25 +737,25 @@ TEST_F(LiteralUtilTest, Copy) { } TEST_F(LiteralUtilTest, CopyScalars) { - auto zero = LiteralUtil::CreateR0(0); - auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); - EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); - - auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); - EXPECT_EQ(LiteralUtil::Get(*zero, {}), 17); - TF_EXPECT_OK(LiteralUtil::Copy(*zero, 
{}, vect.get(), {4}, {})); - EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); + auto zero = Literal::CreateR0(0); + auto nine = Literal::CreateR0(9); + TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {})); + EXPECT_TRUE(zero->Equal(*nine)); + + auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {})); + EXPECT_EQ(zero->Get({}), 17); + TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {})); + EXPECT_EQ(vect->Get({4}), 17); } TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); Literal* l1 = m1.get(); - const char* d1 = static_cast(LiteralUtil::InternalData(*l1)); + const char* d1 = static_cast(l1->InternalData()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -780,14 +764,13 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d1[5], 0); EXPECT_EQ(d1[6], 0); EXPECT_EQ(d1[7], 0); - EXPECT_EQ(LiteralUtil::InternalData(*l1), - LiteralUtil::MutableInternalData(l1)); + EXPECT_EQ(l1->InternalData(), l1->MutableInternalData()); half h1(1.0f); half h2(2.0f); - auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); - const char* d2 = static_cast(LiteralUtil::InternalData(*l2)); + const char* d2 = static_cast(l2->InternalData()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -796,8 +779,7 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d2[5], 0x40); EXPECT_EQ(d2[6], 0); EXPECT_EQ(d2[7], 0x3C); - EXPECT_EQ(LiteralUtil::InternalData(*l2), - LiteralUtil::MutableInternalData(l2)); + EXPECT_EQ(l2->InternalData(), l2->MutableInternalData()); } TEST_F(LiteralUtilTest, Populate) { @@ -818,19 +800,19 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = 
ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = LiteralUtil::CreateFromShape(shape); + auto literal = Literal::CreateFromShape(shape); auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return LiteralUtil::LinearIndex(*literal, indexes) + 17; + return literal->LinearIndex(indexes) + 17; }; - TF_EXPECT_OK(LiteralUtil::Populate(literal.get(), generator)); + TF_EXPECT_OK(literal->Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](const std::vector& indexes) { - auto value = LiteralUtil::Get(*literal, indexes); + auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; @@ -842,65 +824,66 @@ TEST_F(LiteralUtilTest, Populate) { TEST_F(LiteralUtilTest, ConvertR4) { // clang-format off - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); - auto expected = LiteralUtil::CreateR4WithLayout({{ + auto expected = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - auto converted = LiteralUtil::Convert(*original); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, + original->Convert(U32)); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); + EXPECT_TRUE(expected->Equal(*converted)); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format off - auto s8 = LiteralUtil::CreateR4WithLayout({{ + auto s8 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, 
{{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s32 = LiteralUtil::CreateR4WithLayout({{ + auto s32 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u32 = LiteralUtil::CreateR4WithLayout({{ + auto u32 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s64 = LiteralUtil::CreateR4WithLayout({{ + auto s64 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u64 = LiteralUtil::CreateR4WithLayout({{ + auto u64 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto pred = LiteralUtil::CreateR4WithLayout({{ + auto pred = Literal::CreateR4WithLayout({{ {{true, false, true, false}, {false, true, false, true}}, {{false, true, false, true}, {true, false, true, false}}, {{true, false, true, false}, {false, true, false, true}}, }}, layout_r4_dim0major_); - auto int32_pred = LiteralUtil::CreateR4WithLayout({{ + auto int32_pred = Literal::CreateR4WithLayout({{ {{1, 0, 1, 0}, {0, 1, 0, 1}}, {{0, 1, 0, 1}, {1, 0, 1, 0}}, {{1, 0, 1, 0}, {0, 1, 0, 1}}, }}, layout_r4_dim0major_); - auto f32 = LiteralUtil::CreateR4WithLayout({{ + auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - auto f64 = LiteralUtil::CreateR4WithLayout({{ + auto f64 = Literal::CreateR4WithLayout({{ {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, 
@@ -908,40 +891,40 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format on std::unique_ptr conv; - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *u32)); + conv = s8->Convert(U32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*u32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = s8->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U64).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *u64)); + conv = s8->Convert(U64).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*u64)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S64).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s64)); + conv = s8->Convert(S64).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s64)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, PRED).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *pred)); + conv = s8->Convert(PRED).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*pred)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*pred, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *int32_pred)); + conv = pred->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*int32_pred)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*f32, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = f32->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*f64, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = f64->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s32, F32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *f32)); + conv = s32->Convert(F32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*f32)); - 
EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, TUPLE).status().code(), + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, F16).status().code(), + EXPECT_EQ(s32->Convert(F16).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, S16).status().code(), + EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, U16).status().code(), + EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); } @@ -996,9 +979,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { half h1(1.0f); half h2(2.0f); - const char half_vals[8] = { - 0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C - }; + const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C}; LiteralProto p; p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); @@ -1006,7 +987,6 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.clear_f16s(); p.set_f16s(half_vals, 8); - Literal literal(p); ASSERT_EQ(4, literal.f16s_size()); ASSERT_EQ(h1, literal.f16s(0)); @@ -1022,6 +1002,5 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index d488830a6cd7b07ccb8de237121ab0693bd73a0f..70e0f5a74711c8ceef1b6d4225141aa1cc9c6219 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -58,8 +58,7 @@ StatusOr> PackedLiteralReader::Read( } int64 elements = ShapeUtil::ElementsIn(shape); - LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), - result.get()); + result->Resize(elements, std::numeric_limits::quiet_NaN()); std::vector* field = result->mutable_f32s(); char* data = tensorflow::bit_cast(field->data()); uint64 
bytes = elements * sizeof(float); diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index e8de559a5ef9e69864abab21c99887d40cfd378a..7ef5c6d916f52f89a58e107c9526ee312f7369d3 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -86,6 +86,53 @@ namespace xla { return result; } +/* static */ std::unique_ptr> ReferenceUtil::ConvArray3D( + const Array3D& lhs, const Array3D& rhs, int64 kernel_stride, + Padding padding) { + return ConvArray3DGeneralDimensionsDilated( + lhs, rhs, kernel_stride, padding, 1, 1, + ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); +} + +/*static*/ std::unique_ptr> +ReferenceUtil::ConvArray3DGeneralDimensionsDilated( + const Array3D& lhs, const Array3D& rhs, int64 kernel_stride, + Padding padding, int64 lhs_dilation, int64 rhs_dilation, + const ConvolutionDimensionNumbers& dnums) { + CHECK_EQ(dnums.spatial_dimensions_size(), 1); + // Reuse the code for Array4D-convolution by extending the 3D input into a 4D + // array by adding a fourth dummy dimension of size 1 without stride, padding + // and dilation. + Array4D a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); + a4dlhs.Each( + [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); + }); + Array4D a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); + a4drhs.Each( + [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); + }); + // Add a second dummy spatial dimensions. 
+ ConvolutionDimensionNumbers dnums2d = dnums; + dnums2d.add_spatial_dimensions(3); + dnums2d.add_kernel_spatial_dimensions(3); + std::unique_ptr> convr4 = ConvArray4DGeneralDimensionsDilated( + a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, + {rhs_dilation, 1}, dnums2d); + + auto convr3 = MakeUnique>(convr4->planes(), convr4->depth(), + convr4->height()); + convr4->Each( + [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; + }); + return convr3; +} + /* static */ std::unique_ptr> ReferenceUtil::ConvArray4D( const Array4D& lhs, const Array4D& rhs, std::pair kernel_stride, Padding padding) { @@ -135,6 +182,49 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); } +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + std::vector dim_lengths{static_cast(operand.size())}; + auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + window_counts[i] = + WindowCount(dim_lengths[i], window[i], stride[i], padding); + pad_low[i] = padding_both[i].first; + } + auto result = MakeUnique>(window_counts[0]); + + // Do a full 1D reduce window. 
+ for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + int64 i0_base = i0 * stride[0] - pad_low[0]; + + float val = init; + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { + val = reduce_func(val, operand[i0_base + i0_win]); + } + } + (*result)[i0] = val; + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DAdd( + const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, + padding); +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -252,6 +342,20 @@ ReferenceUtil::ReduceWindow4DGeneric( padding); } +/* static */ std::unique_ptr> ReferenceUtil::BatchNorm4D( + const Array4D& input, const Array4D& mean, + const Array4D& var, const Array4D& scale, + const Array4D& offset, float epsilon) { + auto normalized = + *MapArray4D(input, mean, [](float a, float b) { return a - b; }); + normalized = *MapArray4D(normalized, var, [&](float a, float b) { + return a / std::sqrt(b + epsilon); + }); + normalized = + *MapArray4D(normalized, scale, [](float a, float b) { return a * b; }); + return MapArray4D(normalized, offset, [](float a, float b) { return a + b; }); +} + /* static */ std::unique_ptr> ReferenceUtil::SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, @@ -439,21 +543,21 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( // Lambda to access the rhs operand at the given 4D index. height_over_dky // should be equal to height / dky, and width_over_dkx should be equal to // width / dkx. (This is an optimization to avoid doing divisions.) 
- const auto rhs_element = [&]( - int64 kernel_output_feature, int64 kernel_input_feature, int64 height, - int64 width, int64 height_over_dky, int64 width_over_dkx) { - DCHECK_EQ(height % dky, 0); - DCHECK_EQ(width % dkx, 0); - DCHECK_EQ(height / dky, height_over_dky); - DCHECK_EQ(width / dkx, width_over_dkx); - - std::array index; - index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; - index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; - index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; - index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; - return rhs(index[0], index[1], index[2], index[3]); - }; + const auto rhs_element = + [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height, + int64 width, int64 height_over_dky, int64 width_over_dkx) { + DCHECK_EQ(height % dky, 0); + DCHECK_EQ(width % dkx, 0); + DCHECK_EQ(height / dky, height_over_dky); + DCHECK_EQ(width / dkx, width_over_dkx); + + std::array index; + index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; + index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; + index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; + index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; + return rhs(index[0], index[1], index[2], index[3]); + }; // Lambda to access the result data at the given 4D index. 
const auto result_element = [&](int64 batch, int64 kernel_output_feature, @@ -491,13 +595,37 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } } } + if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 || + iz == 0) { + LOG(INFO) << "Output will be trivially empty because one of these " + "dimensions is 0: samples: " + << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox + << " oy: " << oy << " oz: " << oz << " iz: " << iz; + return result; + } + bool trivial = true; + auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice indices, + float value) { + if (value != 0.0) { + trivial = false; + } + }; + lhs.Each(check_trivial); + if (trivial) { + LOG(FATAL) << "LHS is all 0.0."; + } + trivial = true; + rhs.Each(check_trivial); + if (trivial) { + LOG(FATAL) << "RHS is all 0.0."; + } return result; } /* static */ std::unique_ptr> ReferenceUtil::ReduceToColArray2D( const Array2D& matrix, float init, - std::function reduce_function) { + const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); auto result = MakeUnique>(); @@ -514,7 +642,7 @@ ReferenceUtil::ReduceToColArray2D( /* static */ std::unique_ptr> ReferenceUtil::ReduceToRowArray2D( const Array2D& matrix, float init, - std::function reduce_function) { + const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); auto result = MakeUnique>(); @@ -531,7 +659,7 @@ ReferenceUtil::ReduceToRowArray2D( /*static*/ std::vector ReferenceUtil::Reduce4DTo1D( const Array4D& array, float init, tensorflow::gtl::ArraySlice dims, - std::function reduce_function) { + const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); const std::set dim_set(dims.begin(), dims.end()); @@ -566,10 +694,42 @@ ReferenceUtil::ReduceToRowArray2D( return result; } +/* static */ std::unique_ptr> ReferenceUtil::Broadcast1DTo4D( + const std::vector& array, const std::vector& bounds, + int64 broadcast_from_dim) 
{ + auto result = + MakeUnique>(bounds[0], bounds[1], bounds[2], bounds[3]); + for (int64 i = 0; i < result->n1(); ++i) { + for (int64 j = 0; j < result->n2(); ++j) { + for (int64 k = 0; k < result->n3(); ++k) { + for (int64 l = 0; l < result->n4(); ++l) { + switch (broadcast_from_dim) { + case 0: + (*result)(i, j, k, l) = array[i]; + break; + case 1: + (*result)(i, j, k, l) = array[j]; + break; + case 2: + (*result)(i, j, k, l) = array[k]; + break; + case 3: + (*result)(i, j, k, l) = array[l]; + break; + default: + break; + } + } + } + } + } + return result; +} + /* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( const Array3D& array, float init, tensorflow::gtl::ArraySlice dims, - std::function reduce_function) { + const std::function& reduce_function) { CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); int64 cols = dims[0] == 2 ? array.n2() : array.n3(); @@ -665,6 +825,61 @@ ReferenceUtil::ReduceToRowArray2D( return result; } +/* static */ Array3D ReferenceUtil::PadArray3D( + const Array3D& operand, const PaddingConfig& padding, + const float pad) { + CHECK_EQ(padding.dimensions_size(), 3); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3()}; + std::vector pad_low(3); + std::vector pad_high(3); + std::vector pad_interior(3); + std::vector output_bounds(3); + for (int64 i = 0; i < 3; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, pad_low[i]); + CHECK_LE(0, pad_high[i]); + CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); + std::vector indices = {0, 0, 0}; + for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { + for 
(indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { + for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { + float* value = &result(indices[0], indices[1], indices[2]); + bool value_padded = false; + for (int i = 0; i < 3; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + value_padded = true; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + value_padded = true; + } + } + if (value_padded) { + continue; + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); + } + } + } + return result; +} + /* static */ Array4D ReferenceUtil::PadArray4D( const Array4D& operand, const PaddingConfig& padding, const float pad) { diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index f58f0bdc9f51dff62c10dda4aba7aac03e689ce7..2da17307817858eea60e868f4be1ab8138784385 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" @@ -73,6 +74,20 @@ class ReferenceUtil { std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); + // Returns the result of a convolution `lhs rhs`, with the default + // convolution dimension numbers returned from + // ComputationBuilder::CreateDefaultConvDimensionNumbers(). 
+ static std::unique_ptr> ConvArray3D(const Array3D& lhs, + const Array3D& rhs, + int64 kernel_stride, + Padding padding); + + // Returns the result of a convolution `lhs rhs`. + static std::unique_ptr> ConvArray3DGeneralDimensionsDilated( + const Array3D& lhs, const Array3D& rhs, int64 kernel_stride, + Padding padding, int64 lhs_dilation, int64 rhs_dilation, + const ConvolutionDimensionNumbers& dnums); + // Returns the result of a separable convolution with the given parameters. // kernel_stride and padding applies to the depthwise convolution during // the separable convolution. pointwise_weights.depth() must be equal to @@ -87,21 +102,21 @@ class ReferenceUtil { // to apply for each reduction step. static std::unique_ptr> ReduceToColArray2D( const Array2D& matrix, float init, - std::function reduce_function); + const std::function& reduce_function); // Returns the result of reducing a matrix to a row vector. init is the // initial value for the reduce operation, and reduce_function is the function // to apply for each reduction step. static std::unique_ptr> ReduceToRowArray2D( const Array2D& matrix, float init, - std::function reduce_function); + const std::function& reduce_function); // Performs a R2=>R1 reduction by reducing away the dimension specified in // 'dimension_to_reduce'. template static std::vector ReduceR2ToR1(const Array2D& input, int dimension_to_reduce, T init, - std::function freduce) { + const std::function& freduce) { std::vector result(dimension_to_reduce == 0 ? input.n2() : input.n1(), init); for (int i0 = 0; i0 < input.n1(); ++i0) { @@ -118,14 +133,19 @@ class ReferenceUtil { static std::vector Reduce4DTo1D( const Array4D& array, float init, tensorflow::gtl::ArraySlice dims, - std::function reduce_function); + const std::function& reduce_function); + + // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`. 
+ static std::unique_ptr> Broadcast1DTo4D( + const std::vector& array, const std::vector& bounds, + int64 broadcast_from_dim); // Returns the result of reducing the 3D array to a 2D array, reducing away // the dimensions specified in dims. static std::unique_ptr> Reduce3DTo2D( const Array3D& array, float init, tensorflow::gtl::ArraySlice dims, - std::function reduce_function); + const std::function& reduce_function); // Applies map_function to each element in the input (2D array) and returns // the result. @@ -144,19 +164,26 @@ class ReferenceUtil { static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride, Padding padding); - // Performs a 2D window reduction with Add as the function to apply. + // Windowed reductions with Add as the function to apply. + static std::unique_ptr> ReduceWindow1DAdd( + const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); static std::unique_ptr> ReduceWindow2DAdd( const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); - - // Performs a 4D window reduction with Add as the function to apply. static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); - // Performs a 4D window reduction with a generic reduce function. + // Windowed reductions with a generic reduce function. 
+ static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, @@ -169,6 +196,12 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice& stride, const tensorflow::gtl::ArraySlice>& padding); + // Batch normalize data. + static std::unique_ptr> BatchNorm4D( + const Array4D& input, const Array4D& mean, + const Array4D& var, const Array4D& scale, + const Array4D& offset, float epsilon); + // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. static std::unique_ptr> SelectAndScatter4DGePlus( @@ -283,48 +316,56 @@ class ReferenceUtil { return result; } - // Slices the input array given starting indices in each dimension and limit - // indices in each dimension. + // Slices the input array given starting indices, limit indices, and strides + // in each dimension. 
template static std::unique_ptr> Slice2D(const Array2D& input, std::array starts, - std::array limits) { + std::array limits, + std::array strides) { CHECK_LE(starts[0], input.n1()); CHECK_LE(starts[1], input.n2()); CHECK_LE(limits[0], input.n1()); CHECK_LE(limits[1], input.n2()); + CHECK_GE(strides[0], 1); + CHECK_GE(strides[1], 1); auto result = - MakeUnique>(limits[0] - starts[0], limits[1] - starts[1]); + MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { - (*result)(i0, i1) = input(starts[0] + i0, starts[1] + i1); + (*result)(i0, i1) = + input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]); } } return result; } template - static std::unique_ptr> Slice4D(const Array4D& input, - std::array starts, - std::array limits) { + static std::unique_ptr> Slice3D(const Array3D& input, + std::array starts, + std::array limits, + std::array strides) { CHECK_LE(starts[0], input.n1()); CHECK_LE(starts[1], input.n2()); CHECK_LE(starts[2], input.n3()); - CHECK_LE(starts[3], input.n4()); CHECK_LE(limits[0], input.n1()); CHECK_LE(limits[1], input.n2()); CHECK_LE(limits[2], input.n3()); - CHECK_LE(limits[3], input.n4()); + CHECK_GE(strides[0], 1); + CHECK_GE(strides[1], 1); + CHECK_GE(strides[2], 1); auto result = - MakeUnique>(limits[0] - starts[0], limits[1] - starts[1], - limits[2] - starts[2], limits[3] - starts[3]); + MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2])); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { - for (int64 i3 = 0; i3 < result->n4(); ++i3) { - (*result)(i0, i1, i2, i3) = input(starts[0] + i0, starts[1] + i1, - starts[2] + i2, starts[3] + i3); - } + (*result)(i0, i1, i2) = + input(starts[0] + i0 * 
strides[0], starts[1] + i1 * strides[1], + starts[2] + i2 * strides[2]); } } } @@ -332,22 +373,35 @@ class ReferenceUtil { } template - static std::unique_ptr> Slice3D(const Array3D& input, - std::array starts, - std::array limits) { + static std::unique_ptr> Slice4D(const Array4D& input, + std::array starts, + std::array limits, + std::array strides) { CHECK_LE(starts[0], input.n1()); CHECK_LE(starts[1], input.n2()); CHECK_LE(starts[2], input.n3()); + CHECK_LE(starts[3], input.n4()); CHECK_LE(limits[0], input.n1()); CHECK_LE(limits[1], input.n2()); CHECK_LE(limits[2], input.n3()); - auto result = MakeUnique>( - limits[0] - starts[0], limits[1] - starts[1], limits[2] - starts[2]); + CHECK_LE(limits[3], input.n4()); + CHECK_GE(strides[0], 1); + CHECK_GE(strides[1], 1); + CHECK_GE(strides[2], 1); + CHECK_GE(strides[3], 1); + auto result = + MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2]), + CeilOfRatio(limits[3] - starts[3], strides[3])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { - (*result)(i0, i1, i2) = - input(starts[0] + i0, starts[1] + i1, starts[2] + i2); + for (int64 i3 = 0; i3 < result->n4(); ++i3) { + (*result)(i0, i1, i2, i3) = + input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1], + starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]); + } } } } @@ -396,11 +450,51 @@ class ReferenceUtil { return result; } + // Applies map_function to each pair of elements in the input lhs and rhs + // (4D array) and returns the result. 
+ template + static std::unique_ptr> MapArray4D(const Array4D& lhs, + const Array4D& rhs, + F&& map_function) { + return MapWithIndexArray4D( + lhs, rhs, [&](float lhs, float rhs, int64, int64, int64, int64) { + return map_function(lhs, rhs); + }); + } + + // Applies map_function to each pair of element in lhs and rhs (4D array) and + // returns the result. + // (plane, depth, height, width) index of each element is also provided as + // arguments to map_function. + template + static std::unique_ptr> MapWithIndexArray4D( + const Array4D& lhs, const Array4D& rhs, F&& map_function) { + auto result = MakeUnique>(lhs.planes(), lhs.depth(), + lhs.height(), lhs.width()); + for (int64 plane = 0; plane < lhs.planes(); ++plane) { + for (int64 depth = 0; depth < lhs.depth(); ++depth) { + for (int64 height = 0; height < lhs.height(); ++height) { + for (int64 width = 0; width < lhs.width(); ++width) { + (*result)(plane, depth, height, width) = map_function( + lhs(plane, depth, height, width), + rhs(plane, depth, height, width), plane, depth, height, width); + } + } + } + } + return result; + } + // Returns the result of a 2D pad on an input matrix. static std::unique_ptr> PadArray2D( const Array2D& operand, const PaddingConfig& padding, const float pad); + // Returns the result of a 3D pad on an input matrix. + static Array3D PadArray3D(const Array3D& operand, + const PaddingConfig& padding, + const float pad); + // Returns the result of a 4D pad on an input array. static Array4D PadArray4D(const Array4D& operand, const PaddingConfig& padding, diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index f839ac019df07c5c5e07eed856ea55463bb3efae..35b5e8cd52ab0ec21a4bd2df3e9fa0538ae60816 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -52,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -62,7 +63,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -70,7 +71,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto actual_literal = LiteralUtil::CreateR1(*result); + auto actual_literal = Literal::CreateR1(*result); LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -78,7 +79,7 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) { TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto actual_literal = LiteralUtil::CreateR1(*result); + auto actual_literal = Literal::CreateR1(*result); LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -86,7 +87,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { 
TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,7 +97,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -107,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = Literal::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); @@ -124,7 +125,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = Literal::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); @@ -132,6 +133,101 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, SliceArray2D) { + auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); + auto actual_literal = Literal::CreateR2FromArray2D(*result); + + LiteralTestUtil::ExpectR2Near({{1.f, 
2.f}, {4.f, 5.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceStridedArray2D) { + auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); + auto actual_literal = Literal::CreateR2FromArray2D(*result); + + LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceArray3D) { + Array3D input(2, 3, 4); + input.FillIota(0); + + auto result = + ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}}); + auto actual_literal = Literal::CreateR3FromArray3D(*result); + + LiteralTestUtil::ExpectR3Near( + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceStridedArray3D) { + Array3D input(2, 3, 4); + input.FillIota(0); + + auto result = + ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}}); + auto actual_literal = Literal::CreateR3FromArray3D(*result); + + LiteralTestUtil::ExpectR3Near( + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceArray4D) { + Array4D input(2, 3, 4, 5); + input.FillIota(0); + + auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}}, + {{1, 1, 1, 1}}); + auto actual_literal = Literal::CreateR4FromArray4D(*result); + + LiteralTestUtil::ExpectR4Near( + {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceStridedArray4D) { + Array4D input(2, 3, 4, 5); + input.FillIota(0); + + auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}}, + {{1, 2, 2, 2}}); + auto actual_literal = Literal::CreateR4FromArray4D(*result); + + LiteralTestUtil::ExpectR4Near( + {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, + {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, 
ConvArray3DWithSamePadding) { + Array3D input = {{{1, 2, 3, 4}}}; + Array3D weights = {{{5, 6}}}; + std::unique_ptr> actual = + ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame); + Array3D expected = {{{17, 28, 39, 20}}}; + + auto actual_literal = Literal::CreateR3FromArray3D(*actual); + + LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { + Array3D input = {{{1, 2, 3, 4}}}; + Array3D weights = {{{5, 6}}}; + std::unique_ptr> actual = + ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid); + Array3D expected = {{{17, 28, 39}}}; + + auto actual_literal = Literal::CreateR3FromArray3D(*actual); + + LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + ErrorSpec(0.0001)); +} + TEST_F(ReferenceUtilTest, ConvWithSamePadding) { Array4D input(1, 1, 4, 4); // clang-format off @@ -161,7 +257,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { })); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -195,7 +291,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { })); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -247,7 +343,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { }}); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -296,7 +392,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { Array4D expected({{{{2514, 2685}}}}); // clang-format on - auto 
actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -309,7 +405,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { auto actual = ReferenceUtil::ApplyElementwise2D( [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + auto actual_literal = Literal::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, *actual_literal, ErrorSpec(0.0001)); } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0687368b83db343cfa15da969b9f4d9d1a821078..89ebdb0e26a4c03440d771c5867c5dea880311cf 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -24,9 +24,7 @@ xla_proto_library( xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], - deps = [ - "//tensorflow/compiler/xla:xla_data_proto", - ], + deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) # Filegroup used to collect source files for dependency checking. 
@@ -88,11 +86,13 @@ cc_library( deps = [ ":hlo", ":hlo_query", + ":shape_inference", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -106,12 +106,16 @@ cc_test( ":hlo", ":hlo_evaluator", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", @@ -138,6 +142,7 @@ cc_library( deps = [ ":hlo_module_config", ":hlo_proto", + ":hlo_reachability", ":name_uniquer", ":versioned_computation_handle", "//tensorflow/compiler/xla:literal_util", @@ -155,6 +160,31 @@ cc_library( ], ) +cc_library( + name = "hlo_reachability", + srcs = ["hlo_reachability.cc"], + hdrs = ["hlo_reachability.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_test( + name = "hlo_reachability_test", + srcs = ["hlo_reachability_test.cc"], + deps = [ + ":hlo", + ":hlo_reachability", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "hlo_matchers", testonly = 1, @@ -285,7 +315,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", ], ) @@ -303,7 +333,7 @@ cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", ], @@ -330,6 +360,7 @@ cc_library( hdrs = ["backend.h"], deps = [ ":compiler", + ":computation_placer", ":device_memory_allocator", ":platform_util", ":pool", @@ -338,7 +369,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:backend_flags", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -382,7 +412,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -416,7 +446,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -506,9 +536,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", 
"//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", ], @@ -707,9 +736,10 @@ cc_library( ], deps = [ ":buffer_liveness", + ":heap_simulator", ":hlo", - ":hlo_ordering", ":hlo_proto", + ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -718,7 +748,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -736,6 +765,7 @@ cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", + ":hlo_scheduling", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -748,11 +778,61 @@ cc_test( ], ) +cc_library( + name = "hlo_ordering", + srcs = ["hlo_ordering.cc"], + hdrs = ["hlo_ordering.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_ordering_test", + size = "small", + srcs = ["hlo_ordering_test.cc"], + deps = [ + ":hlo", + ":hlo_ordering", + ":hlo_scheduling", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + +cc_library( + name = "heap_simulator", + srcs = ["heap_simulator.cc"], + hdrs = ["heap_simulator.h"], + deps = [ + ":hlo", + ":hlo_ordering", + ":hlo_proto", + ":liveness_util", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_test( name = 
"heap_simulator_test", size = "small", srcs = ["heap_simulator_test.cc"], deps = [ + ":heap_simulator", ":hlo", ":hlo_ordering", ":logical_buffer", @@ -765,23 +845,14 @@ cc_test( ], ) -# The hlo_ordering library contains both hlo_ordering and heap_simulator because -# they are mutually dependent. cc_library( - name = "hlo_ordering", - srcs = [ - "heap_simulator.cc", - "hlo_ordering.cc", - ], - hdrs = [ - "heap_simulator.h", - "hlo_ordering.h", - ], + name = "hlo_scheduling", + srcs = ["hlo_scheduling.cc"], + hdrs = ["hlo_scheduling.h"], deps = [ - ":call_graph", + ":heap_simulator", ":hlo", - ":hlo_proto", - ":liveness_util", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -794,12 +865,13 @@ cc_library( ) cc_test( - name = "hlo_ordering_test", + name = "hlo_scheduling_test", size = "small", - srcs = ["hlo_ordering_test.cc"], + srcs = ["hlo_scheduling_test.cc"], deps = [ ":hlo", ":hlo_ordering", + ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -841,6 +913,46 @@ cc_test( ], ) +cc_library( + name = "batchnorm_rewriter", + srcs = ["batchnorm_rewriter.cc"], + hdrs = ["batchnorm_rewriter.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + ":shape_inference", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "batchnorm_rewriter_test", + size = "small", + srcs = ["batchnorm_rewriter_test.cc"], + deps = [ + ":batchnorm_rewriter", + ":hlo", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + 
"//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + ], +) + cc_library( name = "algebraic_simplifier", srcs = ["algebraic_simplifier.cc"], @@ -948,6 +1060,38 @@ cc_test( ], ) +cc_library( + name = "computation_placer", + srcs = ["computation_placer.cc"], + hdrs = ["computation_placer.h"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform computation placer registration +) + +cc_library( + name = "human_readable_profile_builder", + srcs = ["human_readable_profile_builder.cc"], + hdrs = ["human_readable_profile_builder.h"], + deps = [ + "//tensorflow/compiler/xla:metric_table_report", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "generic_transfer_manager", srcs = ["generic_transfer_manager.cc"], @@ -1030,12 +1174,8 @@ cc_test( cc_library( name = "hlo_cost_analysis", - srcs = [ - "hlo_cost_analysis.cc", - ], - hdrs = [ - "hlo_cost_analysis.h", - ], + srcs = ["hlo_cost_analysis.cc"], + hdrs = ["hlo_cost_analysis.h"], deps = [ ":hlo", "//tensorflow/compiler/xla:shape_util", @@ -1068,6 +1208,7 @@ cc_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], @@ -1080,6 +1221,7 @@ cc_library( deps = [ 
":hlo", ":hlo_cost_analysis", + ":human_readable_profile_builder", "//tensorflow/compiler/xla:metric_table_report", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1137,12 +1279,8 @@ cc_test( cc_library( name = "logical_buffer", - srcs = [ - "logical_buffer.cc", - ], - hdrs = [ - "logical_buffer.h", - ], + srcs = ["logical_buffer.cc"], + hdrs = ["logical_buffer.h"], deps = [ ":hlo", ":hlo_proto", @@ -1155,18 +1293,31 @@ cc_library( ) cc_library( - name = "hlo_dataflow_analysis", - srcs = [ - "hlo_dataflow_analysis.cc", - ], - hdrs = [ - "hlo_dataflow_analysis.h", + name = "hlo_value", + srcs = ["hlo_value.cc"], + hdrs = ["hlo_value.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", ], +) + +cc_library( + name = "hlo_dataflow_analysis", + srcs = ["hlo_dataflow_analysis.cc"], + hdrs = ["hlo_dataflow_analysis.h"], deps = [ ":call_graph", ":hlo", + ":hlo_ordering", + ":hlo_value", ":liveness_util", - "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -1174,7 +1325,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], ) @@ -1201,20 +1351,32 @@ cc_test( ) cc_library( - name = "hlo_alias_analysis", - srcs = [ - "hlo_alias_analysis.cc", - ], - hdrs = [ - "hlo_alias_analysis.h", + name = "hlo_buffer", + srcs = ["hlo_buffer.cc"], + hdrs = ["hlo_buffer.h"], + deps = [ + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + 
"//tensorflow/core:lib", ], +) + +cc_library( + name = "hlo_alias_analysis", + srcs = ["hlo_alias_analysis.cc"], + hdrs = ["hlo_alias_analysis.h"], deps = [ - ":call_graph", ":hlo", + ":hlo_buffer", ":hlo_dataflow_analysis", - ":logical_buffer", - "//tensorflow/compiler/xla:shape_tree", + ":hlo_value", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1245,12 +1407,8 @@ cc_test( cc_library( name = "tuple_points_to_analysis", - srcs = [ - "tuple_points_to_analysis.cc", - ], - hdrs = [ - "tuple_points_to_analysis.h", - ], + srcs = ["tuple_points_to_analysis.cc"], + hdrs = ["tuple_points_to_analysis.h"], deps = [ ":hlo", ":logical_buffer", @@ -1287,12 +1445,8 @@ cc_test( cc_library( name = "compilation_cache", - srcs = [ - "compilation_cache.cc", - ], - hdrs = [ - "compilation_cache.h", - ], + srcs = ["compilation_cache.cc"], + hdrs = ["compilation_cache.h"], deps = [ ":executable", ":hlo_module_config", @@ -1386,7 +1540,10 @@ cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], - deps = [":hlo_pass"], + deps = [ + ":hlo_pass", + "//tensorflow/core:lib", + ], ) cc_library( @@ -1398,9 +1555,9 @@ cc_library( ":call_graph", ":flatten_call_graph", ":hlo", - ":hlo_cost_analysis", ":hlo_dce", ":hlo_ordering", + ":hlo_scheduling", ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", @@ -1497,8 +1654,8 @@ cc_library( "hlo_pass_pipeline.h", ], deps = [ - ":compiler", ":hlo", + ":hlo_graph_dumper", ":hlo_pass", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1572,10 +1729,8 @@ cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - 
"//tensorflow/core:lib", ], ) @@ -1707,8 +1862,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", ], alwayslink = 1, ) @@ -1777,10 +1933,39 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", + ], +) + +cc_library( + name = "reduce_precision_insertion", + srcs = ["reduce_precision_insertion.cc"], + hdrs = ["reduce_precision_insertion.h"], + deps = [ + ":buffer_liveness", + ":hlo", + ":hlo_pass", + ":hlo_pass_pipeline", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", ], ) +cc_test( + name = "reduce_precision_insertion_test", + size = "small", + srcs = ["reduce_precision_insertion_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":reduce_precision_insertion", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 754ac0c68dc025c6d2bde4b40e148e6043f0cf6d..691f9f22964841c1163d161a7c02c2215ba6f066 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -48,7 +48,7 @@ namespace { // Returns whether operand is a literal with the given value. 
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && - LiteralUtil::IsAll(operand->literal(), value); + operand->literal().IsAll(value); } bool IsAll(const HloInstruction* op, int8 value) { @@ -126,10 +126,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; + + Status HandleConvert(HloInstruction* convert) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; @@ -179,11 +181,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) override; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override; - - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleMaximum(HloInstruction* maximum) override; + Status HandleMinimum(HloInstruction* minimum) override; // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -334,16 +333,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. 
- if (operand->opcode() == HloOpcode::kCopy) { + if (copy->operand(0)->opcode() == HloOpcode::kCopy) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, - operand->operands()[0])); + copy, HloInstruction::CreateUnary( + copy->shape(), HloOpcode::kCopy, + copy->mutable_operand(0)->mutable_operand(0))); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, operand); + ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); return Status::OK(); } @@ -415,6 +414,32 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } +static HloInstruction* BuildTupleConstant(HloComputation* computation, + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector elems; + elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); + for (const Literal& child : literal.tuple_literals()) { + elems.push_back(BuildTupleConstant(computation, child)); + } + return computation->AddInstruction(HloInstruction::CreateTuple(elems)); + } else { + return computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique(literal))); + } +} + +Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant, + const Literal& literal) { + // Tuple constants aren't directly supported by any backend. Expand them into + // explicit Tuple instructions. 
+ if (ShapeUtil::IsTuple(constant->shape())) { + return ReplaceInstruction(constant, + BuildTupleConstant(computation_, literal)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) { @@ -448,6 +473,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, subtract)); } + // A/exp(B) => A*exp(-B) + if (rhs->opcode() == HloOpcode::kExp) { + VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); + HloInstruction* negate = + computation_->AddInstruction(HloInstruction::CreateUnary( + divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); + HloInstruction* new_exp = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); + } + + // A/pow(B,C) => A*pow(B,-C) + if (rhs->opcode() == HloOpcode::kPower) { + VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); + HloInstruction* negate = + computation_->AddInstruction(HloInstruction::CreateUnary( + divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(1))); + HloInstruction* new_power = computation_->AddInstruction( + HloInstruction::CreateBinary(divide->shape(), HloOpcode::kPower, + rhs->mutable_operand(0), negate)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, new_power)); + } + + // Simplifying integral division would produce unexpected results. 
+ if (ShapeUtil::ElementIsIntegral(divide->shape())) { + return Status::OK(); + } + + // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) + if (lhs->opcode() == HloOpcode::kDivide && + rhs->opcode() == HloOpcode::kDivide) { + auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(0), + rhs->mutable_operand(1))); + auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), + rhs->mutable_operand(0))); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c)); + } + + // (A / B) / C => A / (B * C) + if (lhs->opcode() == HloOpcode::kDivide) { + auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + return ReplaceWithNewInstruction( + divide, + HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, + lhs->mutable_operand(0), b_times_c)); + } + + // A / (B / C) => (A*C) / B + if (rhs->opcode() == HloOpcode::kDivide) { + auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + return ReplaceWithNewInstruction( + divide, + HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, + a_times_c, rhs->mutable_operand(0))); + } + return Status::OK(); } @@ -469,7 +560,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, ShapeUtil::HasZeroElements(lhs->shape()) || ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -507,7 +598,7 @@ Status 
AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, {0}, add_reduce_computation)); @@ -531,7 +622,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* reduce; if (ShapeUtil::Rank(rhs->shape()) == 1) { auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -571,7 +662,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {lhs->shape().dimensions(0)}), @@ -595,6 +686,16 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { return Status::OK(); } + + // exp(A) * exp(B) => exp(A+B) + if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + auto add = computation_->AddInstruction(HloInstruction::CreateBinary( + multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), + 
rhs->mutable_operand(0))); + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); + } return Status::OK(); } @@ -606,6 +707,17 @@ Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log, ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { return Status::OK(); } + + // ln(pow(A,B)) => B*ln(A) + if (operand->opcode() == HloOpcode::kPower) { + auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( + log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); + return ReplaceWithNewInstruction( + log, + HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, operand->mutable_operand(1))); + } + return Status::OK(); } @@ -792,12 +904,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. -Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - PrimitiveType src_type = operand->shape().element_type(); +Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { + PrimitiveType src_type = convert->operand(0)->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - return ReplaceInstruction(convert, operand); + return ReplaceInstruction(convert, convert->mutable_operand(0)); } return Status::OK(); } @@ -878,10 +989,10 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { } // Verify that the slice shape matches the pad shape. 
- TF_ASSIGN_OR_RETURN(Shape inferred_slice_shape, - ShapeInference::InferSliceShape( - nonzero_pad_shape, start_indices, end_indices, - strides)); + TF_ASSIGN_OR_RETURN( + Shape inferred_slice_shape, + ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices, + end_indices, strides)); TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); std::unique_ptr slice = HloInstruction::CreateSlice( @@ -897,8 +1008,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, HloInstruction* rhs) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); if (IsAll(rhs, 0)) { - auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( - LiteralUtil::One(power->shape().element_type()))); + auto one = HloInstruction::CreateConstant( + Literal::One(power->shape().element_type()).CloneToUnique()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -914,6 +1025,14 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, return Status::OK(); } + // pow(exp(A),B) => exp(A*B) + if (lhs->opcode() == HloOpcode::kExp) { + auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( + power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); + return ReplaceWithNewInstruction( + power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, + a_times_b)); + } VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); if (IsAll(rhs, 2)) { return ReplaceWithNewInstruction( @@ -923,9 +1042,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { - auto* one = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( - LiteralUtil::One(rhs->shape().element_type())))); + auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( + 
Literal::One(rhs->shape().element_type()).CloneToUnique())); return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, one, lhs)); @@ -937,6 +1055,9 @@ StatusOr AlgebraicSimplifierVisitor:: TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast) { bool changed = false; + if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + return false; + } HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); for (HloInstruction* user : reshape_or_broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { @@ -1008,7 +1129,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // dimension. if (ShapeUtil::HasZeroElements(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( - LiteralUtil::CreateFromShape(reshape->shape())); + Literal::CreateFromShape(reshape->shape())); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); } @@ -1208,8 +1329,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - !LiteralUtil::Equal(pad_value->literal(), - reduce_init_value->literal())) { + !pad_value->literal().Equal(reduce_init_value->literal())) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1368,9 +1488,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be // invalid. 
- if (!valid_bitcast_callback_(lhs->shape(), input_shape) || - !valid_bitcast_callback_(rhs->shape(), new_filter_shape) || - !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + if (!valid_bitcast_callback_(input_shape, lhs->shape()) || + !valid_bitcast_callback_(new_filter_shape, rhs->shape()) || + !valid_bitcast_callback_(convolution_shape, dot_output_shape)) { return Status::OK(); } @@ -1396,9 +1516,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand // \ / @@ -1429,9 +1547,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { // Match the following tree: // max_operand operand // \ / @@ -1470,6 +1586,9 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { // module, invalidating iteration. std::vector computations; for (auto& comp : module->computations()) { + if (comp->IsFusionComputation()) { + continue; + } computations.push_back(comp.get()); } for (auto& comp : computations) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index f8919f0caad6d7009d371d8a1893ba5c91110122..4295a3227a837ffc8483b3be59994c9e6ac96aec 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -26,12 +26,13 @@ namespace xla { // A pass which performs AlgebraicSimplications. 
class AlgebraicSimplifier : public HloPassInterface { public: - // Given two shapes, determines if it is valid to bitcast between them after - // considering platform dependent effects on layout like alignment - // restrictions. - // Precondition: the two shapes have layouts, the same number of - // elements and ShapeUtil::ReshapeIsBitcast returns true. - using ValidBitcastCallback = std::function; + // Given shapes 'from_shape' and 'to_shape', determines if it is valid to + // bitcast from 'from_shape' to 'to_shape' after considering platform + // dependent effects on layout like alignment restrictions. Precondition: the + // two shapes have layouts, the same number of elements and + // ShapeUtil::ReshapeIsBitcast returns true. + using ValidBitcastCallback = + std::function; // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. If valid_bitcast_callback diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e4368a7bb25093f70bf78288db2021d36fa7f25a..be71e03e985a285abafc2adf7219b6aca2a775b6 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -55,7 +55,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); @@ -76,7 +76,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); builder.AddInstruction( @@ -99,7 +99,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))); HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); builder.AddInstruction( @@ -123,7 +123,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); @@ -138,6 +138,155 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { EXPECT_EQ(root, param0); } +// Test that (A/B)/C is simplified to A/(B*C). 
+TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(op::Divide(param0, param1), param2)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Multiply(param1, param2))); +} + +// Test that A/(B/C) is simplified to (A*C)/B. 
+TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Divide(param1, param2))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(op::Multiply(param0, param2), param1)); +} + +// Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
+TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* param3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, r0f32, "param3")); + HloInstruction* div0 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1)); + HloInstruction* div1 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param2, param3)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div0, div1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT( + computation->root_instruction(), + op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); +} + +// Test that A/exp(B) is simplified to A*exp(-B). 
+TEST_F(AlgebraicSimplifierTest, DivOfExp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Exp(param1))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(param0, op::Exp(op::Negate(param1)))); +} + +// Test that A/pow(B,C) is simplified to A*pow(B,-C). 
+TEST_F(AlgebraicSimplifierTest, DivOfPower) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* power = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Power(param1, param2))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(param0, op::Power(param1, op::Negate(param2)))); +} + // Test that A/1 is simplified to A for a scalar. 
TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -145,7 +294,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); @@ -167,7 +316,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + Literal::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); @@ -239,6 +388,89 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { op::Exp(op::Subtract(param0, param1))); } +// Test that exp(A)*exp(B) is simplified to exp(A+B) +TEST_F(AlgebraicSimplifierTest, ExpMul) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + 
EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Exp(param0), op::Exp(param1))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Exp(op::Add(param0, param1))); +} + +// Test that pow(exp(A), B) is simplified to exp(A*B) +TEST_F(AlgebraicSimplifierTest, PowExp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Power(op::Exp(param0), param1)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Exp(op::Multiply(param0, param1))); +} + +// Test that ln(pow(A, B)) is simplified to ln(A)*B +TEST_F(AlgebraicSimplifierTest, LnPow) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* pow = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1)); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, 
HloOpcode::kLog, pow)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Log(op::Power(param0, param1))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Log(param0), param1)); +} + // Test that ln(exp(A)) is simplified to A TEST_F(AlgebraicSimplifierTest, LnExp) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -300,7 +532,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); @@ -315,7 +547,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->literal()), 1); + EXPECT_EQ(root->literal().GetFirstElement(), 1); } // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
@@ -325,7 +557,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); @@ -344,8 +576,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape())); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), - 1); + EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } // Test that pow(A, 1) is simplified to A. @@ -355,7 +586,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); @@ -378,7 +609,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* two = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + HloInstruction::CreateConstant(Literal::CreateR0(2))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); @@ -401,7 +632,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* negative_one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); + 
HloInstruction::CreateConstant(Literal::CreateR0(-1))); builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); @@ -416,8 +647,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Constant(), param0)); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), - 1); + EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -451,7 +681,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -519,7 +749,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction::CreateConstant(Literal::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); @@ -550,7 +780,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction::CreateConstant(Literal::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); @@ -735,7 
+965,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); @@ -753,6 +983,34 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { op::Reshape(op::Maximum(param, zero))); } +// Regression test for a bug in the reshape sinking transformation, where +// moving a reshape to a scalar led to a crash. +TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); + HloInstruction* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + + simplifier.Run(module.get()).ValueOrDie(); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); +} + TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1035,7 +1293,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { 
builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); PaddingConfig no_padding; for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); @@ -1066,7 +1324,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {10, 10}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; @@ -1134,7 +1392,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param")); builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, - /*limit_indices=*/{dim0, dim1}, /*slices=*/{1, 1})); + /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(builder.Build()); @@ -1376,9 +1634,9 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, param0, min_value)); builder.AddInstruction( @@ -1406,9 +1664,9 @@ 
TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); builder.AddInstruction( @@ -1437,9 +1695,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kMaximum, param0, max_value)); builder.AddInstruction( @@ -1497,9 +1755,9 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, 
max_value)); HloInstruction* fmax = builder.AddInstruction( @@ -1566,7 +1824,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloComputation::Builder builder(TestName()); HloInstruction* forty_two = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); HloInstruction* broadcast = @@ -1614,7 +1872,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); @@ -1645,7 +1903,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, pad, reduce_init_value, window, @@ -1714,9 +1972,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloComputation::Builder call_builder(TestName() + ".Call"); HloInstruction* zero = call_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); + HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); + HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); builder.AddInstruction( 
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); @@ -1728,6 +1986,26 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); } +// Test that a constant with tuple shape becomes a tuple of constants. +TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { + HloComputation::Builder builder(TestName()); + const float constant_scalar = 7.3f; + std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; + std::unique_ptr value = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get()}); + builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(op::Constant(), op::Constant())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 66d54ad3802fe442decd11335eddf74bdd1cf950..9abe30e3f371cc294c36c1dcd743224b11b0c4f5 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -22,7 +22,6 @@ limitations under the License. 
#define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const { return platform_; } -BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) { - number_of_replicas_ = number_of_replicas; - return *this; -} - -int BackendOptions::number_of_replicas() const { return number_of_replicas_; } - BackendOptions& BackendOptions::set_intra_op_parallelism_threads( int num_threads) { intra_op_parallelism_threads_ = num_threads; @@ -85,20 +77,17 @@ struct Backend::EigenThreadPoolWrapper { /* static */ StatusOr> Backend::CreateBackend( const BackendOptions& options) { - int64 replica_count = options.number_of_replicas(); - if (replica_count == -1) { - legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); - replica_count = flags->xla_replicas; - } perftools::gputools::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); + TF_ASSIGN_OR_RETURN(auto computation_placer, + ComputationPlacer::GetForPlatform(platform)); std::unique_ptr backend( - new Backend(replica_count, platform, compiler, stream_executors, - transfer_manager, options.intra_op_parallelism_threads())); + new Backend(platform, compiler, stream_executors, transfer_manager, + computation_placer, options.intra_op_parallelism_threads())); return std::move(backend); } @@ -132,34 +121,25 @@ StatusOr Backend::BorrowStream( } Backend::Backend( - int64 replica_count, perftools::gputools::Platform* platform, - Compiler* compiler, 
+ perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads) + TransferManager* transfer_manager, ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), - replica_count_(replica_count) { + computation_placer_(computation_placer) { // The given set of stream executors set may include invalid executors. for (se::StreamExecutor* exec : stream_executors) { if (exec != nullptr) { stream_executors_.push_back(exec); } } - CHECK_GE(replica_count, 1) << "Must request at least 1 replica."; - // Create a memory allocator for the valid stream executors. memory_allocator_ = MakeUnique(platform, stream_executors); - - // First check that there are some non-null stream executors to avoid issuing - // an error mentioning replicas in the common case of requesting just 1 - // replica, which means no replication. CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; - CHECK_GE(stream_executors_.size(), replica_count) - << "Requested more replicas than there are devices for backend " - << platform_->Name() << '.'; if (platform->id() == se::host::kHostPlatformId) { inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( @@ -179,36 +159,6 @@ int Backend::default_device_ordinal() const { return default_stream_executor()->device_ordinal(); } -StatusOr> Backend::Replicas( - int device_ordinal) const { - if (stream_executors_[device_ordinal] == nullptr) { - return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); - } - - // Find replica_count_ stream executors starting from the given device - // ordinal. 
- std::vector replicas; - for (se::StreamExecutor* exec : stream_executors_) { - CHECK(exec != nullptr); - if (exec->device_ordinal() >= device_ordinal) { - replicas.push_back(exec); - if (replicas.size() >= replica_count_) { - return replicas; - } - } - } - - return InvalidArgument( - "Not enough devices for replicas for the device ordinal %d", - device_ordinal); -} - -std::vector Backend::Replicas() const { - CHECK_GE(stream_executors_.size(), replica_count_); - return Replicas(default_device_ordinal()).ValueOrDie(); -} - tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { return inter_op_thread_pool_.get(); } diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index e0b15dc43f25244bc1a3e3c5cdc45877d4d11804..b5ca483b7274d20c31e932d748b6a4c9dea926f9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -46,12 +47,6 @@ class BackendOptions { BackendOptions& set_platform(perftools::gputools::Platform* platform); perftools::gputools::Platform* platform() const; - // Set the number of replicas to use when compiling replicated - // programs. The default is -1 meaning that the value is read from - // the xla_replicas flag. - BackendOptions& set_number_of_replicas(int number_of_replicas); - int number_of_replicas() const; - // Sets the thread pool size for parallel execution of an individual operator. // The default value of -1 will result in initializing the thread pool with // the number of threads equal to the number of cores in the system. 
@@ -60,7 +55,6 @@ class BackendOptions { private: perftools::gputools::Platform* platform_ = nullptr; - int number_of_replicas_ = -1; int intra_op_parallelism_threads_ = -1; }; @@ -74,8 +68,7 @@ class Backend { public: using StreamPtr = Pool::SmartPtr; - // Creates a new backend for the given platform with the given number of - // replicas. + // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -92,6 +85,7 @@ class Backend { return memory_allocator_.get(); } TransferManager* transfer_manager() const { return transfer_manager_; } + ComputationPlacer* computation_placer() const { return computation_placer_; } // Returns the number of devices of the platform type which are visible. Not // all of these devices may be usable by XLA. @@ -107,24 +101,13 @@ class Backend { return stream_executors_; } - // Returns the replicas for the default stream executor. - // - // When the number of replicas is R, the first R stream executors are assigned - // to the replicas of the default stream executor. - std::vector Replicas() const; - - // Returns the replicas for the given device_ordinal. The given device ordinal - // is considered to be the first device ordinal among the replicas. Returns an - // error status if the stream executor for the given given device ordinal does - // not exist or if there are not enough stream executors for the replicas. - StatusOr> Replicas( - int device_ordinal) const; - - // Return the stream executor for the given device ordinal. + // Returns the stream executor for the given device ordinal. StatusOr stream_executor( int device_ordinal) const; - // Return the stream executor for the default device ordinal. + // Returns the stream executor for the default device ordinal. This stream + // executor can only be used when the number of computations is 1 (replication + // can be > 1). 
perftools::gputools::StreamExecutor* default_stream_executor() const { CHECK(!stream_executors_.empty()); return stream_executors_[0]; @@ -174,18 +157,19 @@ class Backend { private: struct EigenThreadPoolWrapper; - Backend(int64 replica_count, perftools::gputools::Platform* platform, - Compiler* compiler, + Backend(perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads); + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; perftools::gputools::Platform* platform_; Compiler* compiler_; TransferManager* transfer_manager_; - int64 replica_count_ = -1; + ComputationPlacer* computation_placer_; // Vector of stream executors. stream_executors_[0] is the default executor. std::vector stream_executors_; diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca2d413e11d3ad12bb3cac7695386c3089a21b1b --- /dev/null +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -0,0 +1,286 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// BatchNormRewriterVisitor traverses the HLO computation and rewrites BatchNorm +// operations into smaller operations. +class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + + // Runs the visitor on a computation. + static bool Run(HloComputation* computation, bool rewrite_training_op, + bool rewrite_grad_op); + + // Returns whether any batch norm ops were rewritten. 
+ const bool changed() const { return changed_; } + + ~BatchNormRewriterVisitor() override = default; + + private: + explicit BatchNormRewriterVisitor(HloComputation* computation, + bool rewrite_training_op, + bool rewrite_grad_op) + : computation_(computation), + rewrite_training_op_(rewrite_training_op), + rewrite_grad_op_(rewrite_grad_op) {} + + HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, + HloOpcode opcode) { + HloComputation::Builder b("scalar computation"); + auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + auto scalar_op = b.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), + opcode, scalar_lhs, scalar_rhs)); + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + } + + // Current HloComputation instance the BatchNormRewriter is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_grad_op_; + + // Whether rewrite has occurred. + bool changed_ = false; + + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceWithNewInstruction( + HloInstruction* old_instruction, + std::unique_ptr new_instruction) { + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + old_instruction, std::move(new_instruction))); + changed_ = true; + return Status::OK(); + } + + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. 
+ Status ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction) { + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(old_instruction, new_instruction)); + changed_ = true; + return Status::OK(); + } +}; + +bool BatchNormRewriterVisitor::Run(HloComputation* computation, + bool rewrite_training_op, + bool rewrite_grad_op) { + BatchNormRewriterVisitor visitor(computation, + /*rewrite_training_op=*/rewrite_training_op, + /*rewrite_grad_op=*/rewrite_grad_op); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; +} + +Status BatchNormRewriterVisitor::HandleBatchNormTraining( + HloInstruction* batch_norm) { + if (!rewrite_training_op_) { + return Status::OK(); + } + // Expand batch norm training into smaller HLO ops. + HloInstruction* operand = batch_norm->mutable_operand(0); + const Shape operand_shape = operand->shape(); + int64 feature_index = batch_norm->feature_index(); + const int64 feature_count = operand_shape.dimensions(feature_index); + const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); + auto elements_per_feature = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(size_in_elements / feature_count))); + + HloInstruction* scale = batch_norm->mutable_operand(1); + HloInstruction* offset = batch_norm->mutable_operand(2); + const Shape feature_shape = scale->shape(); + + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + + auto epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + + std::vector dimensions_without_feature; + + for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + if (i != feature_index) { + dimensions_without_feature.push_back(i); + } + } + + auto scale_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); + + auto offset_broadcasted = 
computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); + + HloComputation* add_reduce_computation = + GetScalarBinaryComputation(F32, HloOpcode::kAdd); + + // X^2. + auto operand_squared = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); + // Sum[X]. + auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( + feature_shape, operand, zero, dimensions_without_feature, + add_reduce_computation)); + + // Sum[X^2]. + auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + feature_shape, operand_squared, zero, dimensions_without_feature, + add_reduce_computation)); + + // Fuse two parallel reduces together to improve performance. + auto tuple = computation_->AddInstruction( + HloInstruction::CreateTuple({sum, squared_sum})); + + auto fused = computation_->CreateFusionInstruction( + {tuple, sum, squared_sum, operand_squared}, + HloInstruction::FusionKind::kInput); + + sum = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + + squared_sum = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + + // E[X]. + auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + + auto mean_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); + + // E[X^2]. + auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + + // E^2[X]. + auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kMultiply, mean, mean)); + + // Var[X]. 
+ auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + + auto var_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + + // Var[X] + epsilon. + auto var_add_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + + auto neg_half = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + + // 1 / Sqrt[Var[X] + epsilon]. + auto rsqrt_var_add_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + + // X - E[X]. + auto operand_minus_mean = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + + // (X - E[X]) / Sqrt[Var[X] + epsilon]. + auto normalized = computation_->AddInstruction( + HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon)); + + // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. + auto scaled_normalized = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + + // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. + auto shifted_normalized = computation_->AddInstruction( + HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted)); + + TF_CHECK_OK(ReplaceWithNewInstruction( + batch_norm, + HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + return Status::OK(); +} + +StatusOr BatchNormRewriter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), before:\n" + module->ToString()); + bool changed = false; + // Make a copy of the computations because we may add computations to the + // module, invalidating iteration. 
+ std::vector computations; + for (auto& comp : module->computations()) { + if (comp->IsFusionComputation()) { + continue; + } + computations.push_back(comp.get()); + } + for (auto& comp : computations) { + if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_, + rewrite_grad_op_)) { + changed = true; + } + } + XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..6d176f4849a786e8650013c430527959bdd004a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which rewrites batch norm operations into more operations. Breaking a +// big operation into smaller operations helps leverage our generic fusion +// logic. 
+class BatchNormRewriter : public HloPassInterface { + public: + BatchNormRewriter(bool rewrite_training_op = false, + bool rewrite_grad_op = false) + : rewrite_training_op_(rewrite_training_op), + rewrite_grad_op_(rewrite_grad_op) {} + ~BatchNormRewriter() = default; + tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; } + + // Run operation expander on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + bool rewrite_training_op_; + bool rewrite_grad_op_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..301f31b51ceb9700b71a86ceddf0065dee93b121 --- /dev/null +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" + +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +using BatchNormRewriterTest = HloTestBase; + +// Test that we expand BatchNormTraining. 
+TEST_F(BatchNormRewriterTest, BatchNormTraining) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); + Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); + Shape offset_shape = ShapeUtil::MakeShape(F32, {2}); + + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "activiation")); + + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scale_shape, "scale")); + + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, offset_shape, "offset")); + + builder.AddInstruction(HloInstruction::CreateBatchNormTraining( + ShapeUtil::MakeTupleShape({input_shape, scale_shape, offset_shape}), + param0, param1, param2, + /*epsilon=*/0.001, /*feature_index=*/3)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); + BatchNormRewriter rewriter(/*rewrite_training_op=*/true); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure this operation is expanded. + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index f91eb0207a23fe55394d59ed99a0d08cf16aa285..ae31135a1aeb2807649aceb6e77d6050525ce5a6 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,12 +22,12 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,6 +66,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { + VLOG(4) << "Trying to add " << buffer << " to " << this; CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -212,10 +213,14 @@ bool BufferAssignment::HasTopLevelAllocation( StatusOr BufferAssignment::GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const { + VLOG(3) << "Trying to find unique slice for " << instruction->name() << " [" + << index << "]"; BufferAllocation::Slice result; for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { + VLOG(3) << "Examining buffer " << *buffer; if (HasAllocation(*buffer)) { + VLOG(3) << "Has allocation"; const BufferAllocation::Slice slice = GetAssignedAllocation(*buffer).GetSlice(*buffer); if (result.allocation() == nullptr) { @@ -226,6 +231,8 @@ StatusOr BufferAssignment::GetUniqueSlice( "be determined at compile-time.", instruction->name().c_str(), index.ToString().c_str()); } + } else { + VLOG(3) << "No allocation"; } } if (result.allocation() == nullptr) { @@ -320,8 +327,9 @@ void BufferAssignment::CombineTempAllocations() { // Each temp allocation is placed end-to-end, accounting for alignment. 
// The offset of each buffer in the combined allocation is computed from // the base offset of the allocation. + int64 alignment = color_alignment_(color); const int64 base = - RoundUpToNearest(combined_allocation->size(), alignment_); + RoundUpToNearest(combined_allocation->size(), alignment); combined_allocation->set_size(base + temp_allocation.size()); for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) { const LogicalBuffer* buffer = buffer_offset_size.first; @@ -575,12 +583,13 @@ Status GatherComputationsByAllocationType( /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment, - bool allow_input_output_aliasing, TuplePointsToAnalysis::Colorer colorer) { - BufferAssigner assigner(alignment, allow_input_output_aliasing, - std::move(colorer)); + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, + bool allow_input_output_aliasing, BufferLiveness::Colorer colorer) { + BufferAssigner assigner(allow_input_output_aliasing, std::move(colorer)); return assigner.CreateAssignment(module, std::move(hlo_ordering), - std::move(buffer_size)); + std::move(buffer_size), + std::move(color_alignment)); } bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, @@ -662,7 +671,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, bool is_thread_local, + const HloComputation* computation, const DebugOptions& debug_options, + bool is_thread_local, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, FlatMap>* @@ -786,10 +796,7 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - legacy_flags::BufferAssignmentFlags* flags = - legacy_flags::GetBufferAssignmentFlags(); - if (!flags->xla_enable_buffer_reuse || is_thread_local || - instruction->opcode() == 
HloOpcode::kCustomCall) { + if (is_thread_local || instruction->opcode() == HloOpcode::kCustomCall) { // Custom call operations never have reusable buffers. Also we do not // reuse thread-local buffers for now, because they are dynamically // allocated and their lifetimes are hard to compute. @@ -938,11 +945,13 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( } auto color_map = SplitBuffersByColor(all_buffers_to_assign); for (auto& single_colored_set : color_map) { - VLOG(2) << "Simulating heap for color " << single_colored_set.first; + auto color = single_colored_set.first; + VLOG(2) << "Simulating heap for color " << color; + int64 alignment = assignment->color_alignment_(color); TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( - MakeUnique(alignment_)), + MakeUnique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, @@ -963,11 +972,13 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( CHECK(instruction_sequence != nullptr) << computation->name(); auto color_map = SplitBuffersByColor(buffers_to_assign); for (auto& single_colored_set : color_map) { - VLOG(2) << "Simulating heap for color " << single_colored_set.first; + auto color = single_colored_set.first; + VLOG(2) << "Simulating heap for color " << color; + int64 alignment = assignment->color_alignment_(color); TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( - MakeUnique(alignment_)), + MakeUnique(alignment)), *computation, *instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, @@ -1074,7 +1085,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( // different while instructions. 
void BufferAssigner::AddWhileSetToColocatedBufferSets( const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const LogicalBuffer* while_init_buffer, + const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, const HloComputation& computation, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { @@ -1137,16 +1149,30 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets( continue; } - // Skip predecessor set if the live range of any predecessor buffers - // overlaps with 'while_init_buffer'. Note that tuple element buffer - // forwarding can cause the same buffer to appear on both sides of the - // interference comparison below. - if (std::any_of( - predecessor_while_buffers.begin(), predecessor_while_buffers.end(), - [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) { - return while_init_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_init_buffer, *buffer); - })) { + // Skip predecessor set if the live range of any predecessor + // buffers overlaps with 'while_init_buffer' or + // 'while_result_buffer' (we need to check both since they're + // aliased together, but the points-to analysis is unaware of this + // aliasing). Note that tuple element buffer forwarding can cause + // the same buffer to appear on both sides of the interference + // comparison below. 
+ auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) { + if (while_init_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) { + return true; + } + + if (while_result_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) { + return true; + } + + return false; + }; + + if (std::any_of(predecessor_while_buffers.begin(), + predecessor_while_buffers.end(), + may_interfere_with_init_or_result)) { continue; } @@ -1193,6 +1219,9 @@ void BufferAssigner::BuildColocatedBufferSets( const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); for (const HloComputation* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; + } for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { const HloOpcode opcode = instruction->opcode(); @@ -1209,8 +1238,8 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet(while_hlo->operand(0), index, points_to_analysis, &colocated_set); // Add while.result. - AddBufferToColocatedSet(while_hlo, index, points_to_analysis, - &colocated_set); + auto* result_buffer = AddBufferToColocatedSet( + while_hlo, index, points_to_analysis, &colocated_set); // Add while.cond.parameter. 
AddBufferToColocatedSet( while_hlo->while_condition()->parameter_instruction(0), index, @@ -1224,8 +1253,9 @@ void BufferAssigner::BuildColocatedBufferSets( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); AddWhileSetToColocatedBufferSets( - colocated_set, init_buffer, while_hlo, *computation, - buffer_liveness, buffer_size, colocated_buffer_sets); + colocated_set, init_buffer, result_buffer, while_hlo, + *computation, buffer_liveness, buffer_size, + colocated_buffer_sets); }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; @@ -1300,10 +1330,10 @@ void BufferAssigner::AssignColocatedBufferSets( StatusOr> BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size) { + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) { TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, - BufferLiveness::Run(module, std::move(hlo_ordering), - std::move(colorer_))); + BufferLiveness::Run(module, std::move(hlo_ordering))); VLOG(1) << "Assigning buffers to module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); @@ -1311,8 +1341,9 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); // Can't use MakeUnique because BufferAssignment constructor is private. - std::unique_ptr assignment(new BufferAssignment( - module, std::move(liveness), alignment_, std::move(buffer_size))); + std::unique_ptr assignment( + new BufferAssignment(module, std::move(liveness), std::move(buffer_size), + std::move(color_alignment))); // Assign buffers with the tightest constraints first (colocated buffer sets). 
// Once b/32491382 enables module-level liveness analysis, we may be able @@ -1323,6 +1354,10 @@ StatusOr> BufferAssigner::CreateAssignment( std::vector colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); + TF_RETURN_IF_ERROR(colorer_(assignment->liveness())); + VLOG(3) << "After coloring:"; + XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString()); + AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), &colocated_buffers, &colocated_allocations); @@ -1337,9 +1372,9 @@ StatusOr> BufferAssigner::CreateAssignment( buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/false, colocated_buffers, - colocated_allocations, &buffers_to_assign_sequentially, - assignment.get())); + computation, module->config().debug_options(), + /*is_thread_local=*/false, colocated_buffers, colocated_allocations, + &buffers_to_assign_sequentially, assignment.get())); } // Assign buffers with sequential ordering, if any. If all global computations // are sequential, we can run heap simuation on the whole module, which @@ -1354,10 +1389,13 @@ StatusOr> BufferAssigner::CreateAssignment( // their own BufferAllocation. 
for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); + if (computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/true, colocated_buffers, - colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, - assignment.get())); + computation, module->config().debug_options(), + /*is_thread_local=*/true, colocated_buffers, colocated_allocations, + /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); } // Mark all buffers which may be live out of the entry computation as diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index b3933f11c1e6ae3e7ffcc990442183338788caf4..35c904df130564a4848d3cb2db21ed8fa209e7e8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -351,12 +351,12 @@ class BufferAssignment { explicit BufferAssignment(const HloModule* module, std::unique_ptr liveness, - int64 alignment, - LogicalBuffer::SizeFunction buffer_size) + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), - alignment_(alignment), - buffer_size_(std::move(buffer_size)) {} + buffer_size_(std::move(buffer_size)), + color_alignment_(std::move(color_alignment)) {} // Creates and returns a new BufferAllocation, with no assigned // LogicalBuffers. Ownership is maintained internally. @@ -402,11 +402,13 @@ class BufferAssignment { const HloModule* module_; const std::unique_ptr liveness_; - const int64 alignment_; // Function which returns the buffer size for a given logical buffer (shape). LogicalBuffer::SizeFunction buffer_size_; + // Function which returns the alignment for a given logical buffer color. 
+ LogicalBuffer::AlignmentFunction color_alignment_; + Stats stats_; std::vector heap_simulator_traces_; @@ -417,36 +419,37 @@ class BufferAssignment { class BufferAssigner { public: // Build and return a BufferAssignment for the given module. The given - // HloOrdering is used to determine buffer liveness. buffer_size is a function - // which returns the size of a LogicalBuffer. Alignment is the minimum - // alignment of any buffer. allow_input_output_aliasing specifies whether - // input buffer are allowed to be reused as outbut buffers by the client code. + // HloOrdering is used to determine buffer liveness. buffer_size and + // color_alignment are functions which returns the size and alignment of a + // LogicalBuffer. allow_input_output_aliasing specifies whether input buffer + // are allowed to be reused as outbut buffers by the client code. static StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing = false, - TuplePointsToAnalysis::Colorer colorer = - TuplePointsToAnalysis::DefaultColorer()); + BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer()); private: - BufferAssigner(int64 alignment, bool allow_input_output_aliasing, - TuplePointsToAnalysis::Colorer colorer) - : alignment_(alignment), - allow_input_output_aliasing_(allow_input_output_aliasing), + BufferAssigner(bool allow_input_output_aliasing, + BufferLiveness::Colorer colorer) + : allow_input_output_aliasing_(allow_input_output_aliasing), colorer_(colorer) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. 
StatusOr> CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size); + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment); // Assigns buffers to the instructions in the given computation. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is // true, then all assigned buffers have the is_thread_local flag set to // true. Status AssignBuffersForComputation( - const HloComputation* computation, bool is_thread_local, + const HloComputation* computation, const DebugOptions& debug_options, + bool is_thread_local, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, @@ -511,7 +514,8 @@ class BufferAssigner { // colocated buffers for while instructions. void AddWhileSetToColocatedBufferSets( const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const LogicalBuffer* while_init_buffer, + const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, const HloComputation& computation, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets); @@ -524,15 +528,12 @@ class BufferAssigner { SplitBuffersByColor( const tensorflow::gtl::FlatSet& buffers); - // Minimum alignment of any buffer. - int64 alignment_; - // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. bool allow_input_output_aliasing_; // Functor used to assign colors to newly allocated logical buffers. 
- TuplePointsToAnalysis::Colorer colorer_; + BufferLiveness::Colorer colorer_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 892f67a8812823a6f156dc6098bf6b39fa800d3c..18acd4f3ae47882bf629c090c510db92049e215a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -85,17 +86,18 @@ class BufferAssignmentTest : public HloTestBase { int64 alignment = 1) { return BufferAssigner::Run( module, MakeUnique(module), - backend_->compiler()->BufferSizeBytesFunction(), alignment) + backend_->compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); } std::unique_ptr RunColoredBufferAssignment( - HloModule* module, TuplePointsToAnalysis::Colorer colorer, - int64 alignment = 1) { - return BufferAssigner::Run(module, - MakeUnique(module), - backend_->compiler()->BufferSizeBytesFunction(), - alignment, false, std::move(colorer)) + HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { + return BufferAssigner::Run( + module, MakeUnique(module), + backend_->compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }, false, + std::move(colorer)) .ConsumeValueOrDie(); } @@ -105,7 +107,7 @@ class BufferAssignmentTest : public HloTestBase { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); auto value 
= builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); return builder.Build(); @@ -122,7 +124,7 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + HloInstruction::CreateConstant(Literal::CreateR0(4))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( @@ -147,9 +149,9 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto constv = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto indexc = builder.AddInstruction( @@ -264,7 +266,7 @@ static bool BuffersDistinct(const std::vector& a, TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -278,9 +280,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) { // no buffers assigned, and their consumer has a buffer. 
auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); + Literal::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); auto module = CreateNewModule(); @@ -298,7 +300,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { // This computation copies a constant to output. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); auto module = CreateNewModule(); @@ -378,12 +380,16 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunColoredBufferAssignment( - module.get(), - [](const HloInstruction* instruction, const ShapeIndex& index) { - static int64 serial = 0; - return LogicalBuffer::Color(serial++); - }); + auto colorer = [](const BufferLiveness& buffer_liveness) { + int color = 0; + for (auto& buffer : + buffer_liveness.points_to_analysis().logical_buffers()) { + buffer->set_color(LogicalBuffer::Color(color++)); + } + return Status::OK(); + }; + + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -430,14 +436,25 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunColoredBufferAssignment( - module.get(), - [](const HloInstruction* instruction, const ShapeIndex& index) { - return (instruction->opcode() == HloOpcode::kAdd || - instruction->opcode() == HloOpcode::kMultiply) - ? LogicalBuffer::Color(1) - : LogicalBuffer::Color(0); - }); + auto colorer = [](const BufferLiveness& buffer_liveness) { + for (auto& buffer : + buffer_liveness.points_to_analysis().logical_buffers()) { + const auto& aliases = + buffer_liveness.points_to_analysis().GetBufferAliases(*buffer); + for (const auto& alias : aliases) { + if (alias.instruction()->opcode() == HloOpcode::kAdd || + alias.instruction()->opcode() == HloOpcode::kMultiply) { + buffer->set_color(LogicalBuffer::Color(1)); + } + } + if (!buffer->has_color()) { + buffer->set_color(LogicalBuffer::Color(0)); + } + } + return Status::OK(); + }; + + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -586,7 +603,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto exp2 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( /*shape=*/f32vec10_, /*operand=*/exp2, @@ -634,9 +651,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Creates the main kernel and verifies instruction counts. 
auto builder = HloComputation::Builder(TestName()); auto const3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( @@ -996,9 +1013,10 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { // Test a computation that returns a tuple parameter. auto builder = HloComputation::Builder(TestName()); auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), - ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeShape(S32, {42})}), + 0, + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {42})}), "param0")); auto module = CreateNewModule(); @@ -1027,10 +1045,11 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { // parameter. auto builder = HloComputation::Builder(TestName()); auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}), - ShapeUtil::MakeShape(S32, {101})})}), + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}), + ShapeUtil::MakeShape(S32, {101})})}), "param0")); auto tuple_element = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -1075,9 +1094,8 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output is // properly handled. 
auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple( + {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1369,9 +1387,9 @@ class WhileBufferAssignmentTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto ten = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + HloInstruction::CreateConstant(Literal::CreateR0(10))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); return builder.Build(); @@ -1399,7 +1417,8 @@ class WhileBufferAssignmentTest : public HloTestBase { CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, MakeUnique(module, sequence), - ByteSizeOf, alignment) + ByteSizeOf, + [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); } @@ -1429,7 +1448,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateParameter(2, data_shape_, "weights1")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1484,7 +1503,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = 
builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1532,16 +1551,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param")); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); sub_computation = module->AddEmbeddedComputation(builder.Build(add)); } auto builder = HloComputation::Builder(TestName()); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto call1 = builder.AddInstruction( HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); auto call2 = builder.AddInstruction( @@ -1554,7 +1573,7 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); std::unique_ptr call_graph = CallGraph::Build(module.get()); } @@ -1565,6 +1584,105 @@ TEST_F(BufferAssignmentTest, TwoCalls) { EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } +static bool IsPostOrderTraversal( + const std::vector& sequence) { + tensorflow::gtl::FlatSet seen_so_far; + auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { + return seen_so_far.count(instruction) == 0; + }; + + for (auto instruction : 
sequence) { + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), has_not_been_seen_yet) || + std::any_of(instruction->control_predecessors().begin(), + instruction->control_predecessors().end(), + has_not_been_seen_yet)) { + return false; // Not a post order. + } + if (!seen_so_far.insert(instruction).second) { + return false; // Not a "traversal". + } + } + + return true; +} + +TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder(TestName()); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto input1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, data_shape_, "input1")); + auto weights1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, data_shape_, "weights1")); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, one, {1})); + + auto cond = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input1, weights1, output1})); + + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); + + 
auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( + while0->shape(), HloOpcode::kAdd, while0, while1)); + module->AddEntryComputation(builder.Build()); + + RunCopyInsertion(module.get()); + + { + FlattenCallGraph flatten; + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + EXPECT_TRUE(result); + } + + auto sequence = + CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + + // To trigger b/38494731, we want a specific Hlo sequence for the + // root computation, so we overwrite that entry with a manually + // crafted sequence. + std::vector sequence_for_buffer_assigment = { + input1, weights1, one, output1, tuple1, while1, input0, + weights0, zero, output0, tuple0, while0, root_add}; + + // If this ASSERT_TRUE fails, we constructed a bogus sequence above + // and this test itself is buggy. + ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); + + sequence[module->entry_computation()] = + std::move(sequence_for_buffer_assigment); + + auto assignment = + BufferAssigner::Run( + module.get(), + MakeUnique(module.get(), sequence), ByteSizeOf, + [](LogicalBuffer::Color) { return 1; }) + .ConsumeValueOrDie(); + + EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); +} + // Test buffer assignment for while nodes with multiple uses. // TODO(b/37245345): Fix buffer assignment for this case. 
TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { @@ -1577,7 +1695,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -1605,7 +1723,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { { FlattenCallGraph flatten; - TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 1b14c26340f6c1922bf35457fe7f1367ed953df0..f085ffa6bc40b212339a97604455a07c1e662952 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -37,18 +37,19 @@ namespace xla { /* static */ StatusOr> BufferLiveness::Run( - const HloModule* module, std::unique_ptr hlo_ordering, - TuplePointsToAnalysis::Colorer colorer) { + const HloModule* module, std::unique_ptr hlo_ordering) { std::unique_ptr liveness( - new BufferLiveness(module, std::move(hlo_ordering), std::move(colorer))); + new BufferLiveness(module, std::move(hlo_ordering))); TF_RETURN_IF_ERROR(liveness->Analyze()); return std::move(liveness); } tensorflow::Status BufferLiveness::Analyze() { - TF_ASSIGN_OR_RETURN(points_to_analysis_, - TuplePointsToAnalysis::Run(module_, colorer_)); + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto& computation : module_->computations()) { + if (computation->IsFusionComputation()) { + continue; + } // Gather all instructions whose buffers might alias other instructions into // the set aliased_buffers_. 
This includes those contained as a tuple // element in other instruction's output. @@ -122,7 +123,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), b.instruction(), b.index(), - points_to_analysis())) { + &points_to_analysis())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 9bb2564a8312f0d80e01f40cb18f99d5ad0e1771..70d642b40c8e4f51748f736c69795a94ccc30de2 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -36,12 +36,12 @@ namespace xla { // interference. class BufferLiveness { public: + using Colorer = std::function; + // Constructs a buffer liveness object for the given module assuming the given // HLO instruction ordering. static StatusOr> Run( - const HloModule* module, std::unique_ptr hlo_ordering, - TuplePointsToAnalysis::Colorer colorer = - TuplePointsToAnalysis::DefaultColorer()); + const HloModule* module, std::unique_ptr hlo_ordering); // Returns true if the live range of the buffer containing the output of 'a' // may overlap with the live range of the buffer of 'b'. If instruction 'a' @@ -67,15 +67,24 @@ class BufferLiveness { // Returns the underlying hlo ordering used for this liveness analysis. 
const HloOrdering& hlo_ordering() const { return *hlo_ordering_; } + const HloModule& module() const { return *module_; } + string ToString() const; + static Colorer DefaultColorer() { + return [](const BufferLiveness& buffer_liveness) { + for (auto& buffer : + buffer_liveness.points_to_analysis().logical_buffers()) { + buffer->set_color(LogicalBuffer::Color(0)); + } + return Status::OK(); + }; + } + private: explicit BufferLiveness(const HloModule* module, - std::unique_ptr hlo_ordering, - TuplePointsToAnalysis::Colorer colorer) - : module_(module), - hlo_ordering_(std::move(hlo_ordering)), - colorer_(colorer) {} + std::unique_ptr hlo_ordering) + : module_(module), hlo_ordering_(std::move(hlo_ordering)) {} // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. @@ -98,8 +107,6 @@ class BufferLiveness { tensorflow::gtl::FlatSet maybe_live_out_buffers_; std::unique_ptr points_to_analysis_; - - TuplePointsToAnalysis::Colorer colorer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index fda44ff4d2df18b90d308617cf845c9946227249..a5f7cc0aebe856931a122eb4bf56f87666ee38a0 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -397,13 +397,11 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. 
auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + auto inner_tuple0 = Literal::MakeTuple( + {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}); + auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0(3).get()}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0->shape(), tuple_constant, 0)); @@ -450,7 +448,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -462,7 +460,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element1_shape, tuple_param0, 1)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); @@ -513,7 +511,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -585,7 +583,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); HloInstruction* slice = nullptr; if (update_uses_tuple_element1) { // Create a slice instruction as an additional user of 'gte1'. @@ -596,7 +594,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -715,7 +713,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); if (tuple_element1_has_two_uses) { // Add 'gte0' and 'gte1' to create another user of 'gte1'. @@ -724,7 +722,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index fa7b2a309525dd80d655e10474c5d49f9da14ea8..b450e0c40074344778109ed2ba8b2238cff7940e 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -133,6 +133,37 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { return nodes_[it->second]; } +bool CallGraph::DominatesHelper( + const HloComputation* a, const HloComputation* b, + tensorflow::gtl::FlatSet* visited) const { + if (a == b || ContainsKey(*visited, b)) { + // The call graph is guaranteed to be acyclic so any previously visited node + // we encounter was already determined to be dominated. + return true; + } + + const CallGraphNode& b_node = GetNode(b); + if (b_node.callers().empty()) { + // We reached a root node without hitting 'a'. 'a' does not dominate 'b'. + return false; + } + + // Walk up the callers of 'b' until we hit 'a' or a root node (no callers). 
+ visited->insert(b); + for (const HloComputation* b_caller : b_node.callers()) { + if (!DominatesHelper(a, b_caller, visited)) { + return false; + } + } + return true; +} + +bool CallGraph::Dominates(const HloComputation* a, + const HloComputation* b) const { + tensorflow::gtl::FlatSet visited; + return DominatesHelper(a, b, &visited); +} + namespace { // Returns the call context of a computation which is called from contexts 'a' diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 7f9990f06d4fee4c52fa516fc2f6031f5dab2bb9..a3297ff534f429279fd4674517db545f289af627 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -189,6 +189,20 @@ class CallGraph { Status VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes = true) const; + // Returns true if 'a' dominates 'b' in the call graph. Computation 'a' + // dominates computation 'b' iff all callgraph paths in the caller-to-callee + // direction from a root computation to 'b' pass through computation + // 'a'. Trivially, a computation dominates itself. + bool Dominates(const HloComputation* a, const HloComputation* b) const; + + // Returns whether 'instruction' is contained in 'computation' either directly + // ('instruction->parent' is 'computation') or indirectly ('computation' + // dominates 'instruction->parent' in the call graph). + bool InstructionIsNestedIn(const HloInstruction* instruction, + const HloComputation* computation) const { + return Dominates(computation, instruction->parent()); + } + string ToString() const; private: @@ -205,6 +219,13 @@ class CallGraph { const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const; + // Recursive helper for computing whether 'a' dominates 'b' in the call + // graph. 
'b_ancestor' is the currently visited node (which starts at 'b'), + // and 'visited' is the set of computations which have been visited. + bool DominatesHelper( + const HloComputation* a, const HloComputation* b, + tensorflow::gtl::FlatSet* visited) const; + // The HLO module represented by this call graph. const HloModule* module_ = nullptr; diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index e276473c90aa3fcc6b494537db6bceb841ade91e..3c22871b3bff193c27ee2eb639fe72306d532b97 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -81,7 +81,7 @@ class CallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -314,6 +314,37 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_LT(index_of(cond_computation), index_of(a_computation)); EXPECT_LT(index_of(c_computation), index_of(b_computation)); EXPECT_LT(index_of(b_computation), index_of(a_computation)); + + // Verify dominance relations between computation in the graph. + + // Entry dominates everybody, and is dominated by no one except itself. 
+ EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation)); + EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation)); + + // 'a' only dominates 'b' and 'c'. + EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation)); + EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation)); + EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation)); + EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation)); + + EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation)); + + EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation)); + EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation)); + + EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation)); } TEST_F(CallGraphTest, VisitSingletonComputation) { diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 0d1a439724a95231240227cfdf089cb2d74b3dd2..d43dc5b214a95edf3be726b318fd379164edbd9f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ 
b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -52,17 +52,17 @@ CompileOnlyService::NewService(const ServiceOptions& options) { TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); - std::unique_ptr service( - new CompileOnlyService(compiler, std::move(compute_constant_backend))); + std::unique_ptr service(new CompileOnlyService( + options, compiler, std::move(compute_constant_backend))); return std::move(service); } CompileOnlyService::CompileOnlyService( - Compiler* compiler, std::unique_ptr compute_constant_backend) - : Service(/*backend=*/nullptr, std::move(compute_constant_backend)), - compiler_(compiler) { - runs_in_client_process_ = true; -} + const ServiceOptions& options, Compiler* compiler, + std::unique_ptr compute_constant_backend) + : Service(options, /*backend=*/nullptr, + std::move(compute_constant_backend)), + compiler_(compiler) {} StatusOr>> CompileOnlyService::CompileAheadOfTime( @@ -75,9 +75,11 @@ CompileOnlyService::CompileAheadOfTime( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); + // TODO(b/63773457): Track DebugOptions in AotCompilationOptions. + DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + // Dump computation proto state if flag is set. 
- legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - const string& directory_path = flags->xla_dump_computations_to; + const string& directory_path = debug_options.xla_dump_computations_to(); if (!directory_path.empty()) { TF_ASSIGN_OR_RETURN( std::unique_ptr session_module, @@ -95,11 +97,10 @@ CompileOnlyService::CompileAheadOfTime( user_computation->ComputeProgramShape(versioned_handle.version)); HloModuleConfig hlo_module_config(*program_shape); - hlo_module_config.set_debug_options( - legacy_flags::GetDebugOptionsFromFlags()); + hlo_module_config.set_debug_options(debug_options); auto* computation_layout = hlo_module_config.mutable_entry_computation_layout(); - if (flags->xla_hlo_profile) { + if (debug_options.xla_hlo_profile()) { hlo_module_config.enable_hlo_profiling(true); } for (int i = 0; i < instance.argument_layouts.size(); ++i) { @@ -122,8 +123,7 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), - MakeHloDumper(), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index b19f4bd592162045a41e2ec82266826ce84096ef..0a1911cbd15b0278ec2c3ccc944ce4df80a683ed 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -55,7 +55,7 @@ class CompileOnlyService : public Service { // Override Service methods that require or imply the existence of an // execute backend. Note that this does not include TransferToClient, as - // computing contants produces global data that we may wish to transfer. + // computing constants produces global data that we may wish to transfer. 
tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); @@ -103,7 +103,8 @@ class CompileOnlyService : public Service { private: explicit CompileOnlyService( - Compiler* compiler, std::unique_ptr compute_constant_backend); + const ServiceOptions& options, Compiler* compiler, + std::unique_ptr compute_constant_backend); CompileOnlyService(const CompileOnlyService&) = delete; void operator=(const CompileOnlyService&) = delete; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 7ae285170e4b99ecf036eeb81eaee49ef34034ea..d5bd9214be44f4abd5f672168335ae1a259c9118 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -92,13 +92,6 @@ class AotCompilationOptions { // platform. class Compiler { public: - // Callback signature used to dump the HLO graph during compilation. - // Different compiler backends will call this as they please, providing - // a view of the HLO at different points in compilation -- context for the - // dump is indicated by the label string. - using HloDumper = - std::function; - virtual ~Compiler() {} // Returns the ID of the platform that this compiler targets. @@ -113,21 +106,20 @@ class Compiler { // // Use the overload below to compile computations that run in parallel. virtual StatusOr> Compile( - std::unique_ptr module, HloDumper dump_hlo, + std::unique_ptr module, perftools::gputools::StreamExecutor* executor) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. virtual StatusOr>> Compile( - std::vector> modules, HloDumper dump_hlo, + std::vector> modules, std::vector stream_exec) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. 
virtual StatusOr>> CompileAheadOfTime(std::vector> modules, - HloDumper dump_hlo, const AotCompilationOptions& options) = 0; ///// diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc new file mode 100644 index 0000000000000000000000000000000000000000..cdfa30dd9a7b6a5b9e58087491a9d99caaa1b998 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -0,0 +1,152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/computation_placer.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { + proto->set_replica_count(replica_count()); + proto->set_computation_count(computation_count()); + for (int computation = 0; computation < computation_count(); ++computation) { + DeviceAssignmentProto::ComputationDevice* computation_device = + proto->add_computation_devices(); + for (int replica = 0; replica < replica_count(); ++replica) { + computation_device->add_replica_device_ids((*this)(replica, computation)); + } + } + return Status::OK(); +} + +/* static */ StatusOr> +DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { + TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); + auto assignment = MakeUnique(proto.replica_count(), + proto.computation_count()); + for (int computation = 0; computation < proto.computation_count(); + ++computation) { + const auto& computation_device = proto.computation_devices(computation); + TF_RET_CHECK(computation_device.replica_device_ids_size() == + proto.replica_count()); + for (int replica = 0; replica < proto.replica_count(); ++replica) { + (*assignment)(replica, computation) = + 
computation_device.replica_device_ids(replica); + } + } + return std::move(assignment); +} + +StatusOr ComputationPlacer::DeviceId(int replica, int computation, + int replica_count, + int computation_count) { + TF_RET_CHECK(replica < replica_count); + TF_RET_CHECK(computation < computation_count); + + return computation * replica_count + replica; +} + +StatusOr ComputationPlacer::AssignDevices( + int replica_count, int computation_count) { + DeviceAssignment assignment(replica_count, computation_count); + for (int replica = 0; replica < replica_count; ++replica) { + for (int computation = 0; computation < computation_count; ++computation) { + TF_ASSIGN_OR_RETURN( + int device_id, + DeviceId(replica, computation, replica_count, computation_count)); + assignment(replica, computation) = device_id; + } + } + return std::move(assignment); +} + +/* static */ void ComputationPlacer::RegisterComputationPlacer( + se::Platform::Id platform_id, + ComputationPlacerCreationFunction creation_function) { + tensorflow::mutex_lock lock( + *ComputationPlacer::platform_computation_placer_mutex()); + auto* computation_placers = GetPlatformComputationPlacers(); + CHECK(computation_placers->find(platform_id) == computation_placers->end()); + (*computation_placers)[platform_id].creation_function = creation_function; +} + +/* static */ StatusOr ComputationPlacer::GetForPlatform( + const se::Platform* platform) { + tensorflow::mutex_lock lock( + *ComputationPlacer::platform_computation_placer_mutex()); + auto* computation_placers = GetPlatformComputationPlacers(); + + auto it = computation_placers->find(platform->id()); + if (it == computation_placers->end()) { + return NotFound( + "could not find registered computation placer for platform %s -- check " + "target linkage", + platform->Name().c_str()); + } + + if (it->second.placer == nullptr) { + // Lazily create the computation placer the first time it is needed. 
+ it->second.placer = (*it->second.creation_function)(); + } + + return it->second.placer.get(); +} + +/* static */ tensorflow::mutex* +ComputationPlacer::platform_computation_placer_mutex() { + static tensorflow::mutex* m = new tensorflow::mutex; + return m; +} + +/* static */ std::map* +ComputationPlacer::GetPlatformComputationPlacers() { + static auto* r = + new std::map; + return r; +} + +} // namespace xla + +static std::unique_ptr CreateComputationPlacer() { + return xla::MakeUnique(); +} + +static bool InitModule() { + xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId, + &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId, + &CreateComputationPlacer); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h new file mode 100644 index 0000000000000000000000000000000000000000..7d9abcd100dd9e878da885110bc1bd1ac65e3f84 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class that represents the device assignment for a set of XLA replicated +// computations. For R replicas and C computations, R * C devices are required +// to execute the computation in parallel. The assigned device ids can be accessed +// by assignment(replica, computation). +class DeviceAssignment : public Array2D { + public: + DeviceAssignment() {} + DeviceAssignment(int replica_count, int computation_count) + : Array2D(replica_count, computation_count, -1) { + CHECK_GT(replica_count, 0); + CHECK_GT(computation_count, 0); + } + + int replica_count() const { return height(); } + int computation_count() const { return width(); } + + // Protocol buffer serialization and deserialization. + Status Serialize(DeviceAssignmentProto* proto) const; + + // Return a std::unique_ptr instead of a DeviceAssignment + // directly because one of the supported TF platforms (mac) does not compile + // due to a StatusOr of an incomplete type (DeviceAssignment). + static StatusOr> Deserialize( + const DeviceAssignmentProto& proto); +}; + +// A generic implementation of the XLA computation placer, which assigns device +// ids to a set of replicated computations. 
+class ComputationPlacer { + public: + ComputationPlacer() {} + virtual ~ComputationPlacer() {} + + // Returns the device id assigned to the given replica and computation + // instance for [replica_count x computation_count] setup. The returned device + // id must match the assignment from PlaceReplicatedComputation(). + virtual StatusOr DeviceId(int replica, int computation, + int replica_count, int computation_count); + + // Returns the device ids assigned to a set of replicated computations, given + // the number of replicas and the number of computations. + virtual StatusOr AssignDevices(int replica_count, + int computation_count); + + using ComputationPlacerCreationFunction = + std::unique_ptr (*)(); + + // Registers a computation placer creation function for a particular platform. + static void RegisterComputationPlacer( + perftools::gputools::Platform::Id platform_id, + ComputationPlacerCreationFunction creation_function); + + // Returns the computation placer singleton pointer if it is available for the + // given platform, or an error status if it is not. + static StatusOr GetForPlatform( + const perftools::gputools::Platform* platform); + + private: + // Routine that returns the mutex that guards the platform-to-computation + // placer map. Done as a routine to ensure correct initialization ordering, + // since RegisterComputationPlacer can be called during program initialization + // time. + static tensorflow::mutex* platform_computation_placer_mutex(); + + // State kept for each kind of ComputationPlacer. Registration functions set + // up creation_function, and then we use that to lazily create "placer" the + // first time GetForPlatform is invoked for a particular id. + struct State { + std::unique_ptr placer; + ComputationPlacerCreationFunction creation_function = nullptr; + }; + + // Map from platform kind to computation placer singleton. 
+ static std::map* + GetPlatformComputationPlacers(); + + perftools::gputools::Platform::Id platform_id_; + + TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc index 9aa32a1fb76616e6c81043fabb053570a86d2619..70e25eebdb068db893e24aec0f72d09090ac7027 100644 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ b/tensorflow/compiler/xla/service/computation_tracker.cc @@ -216,6 +216,7 @@ StatusOr> ComputationTracker::BuildHloModule( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_computation, computation->BuildHloComputation(versioned_handle.version, resolver, + config.debug_options(), include_unreachable_instructions)); // Add the newly created computation to VersionedHandle-to-HloComputation diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index a3803c34ba7db99db7139b53cad22d5bce7fe5e6..c47abe9c62a40716eb03fbd2213b941b5e0abbc3 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -551,6 +551,9 @@ StatusOr CopyInsertion::Run(HloModule* module) { // Add copies of computation root instructions, if needed. 
FlatMap> while_body_read_only_indices; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } VLOG(2) << "computation " << computation->name(); InstructionCopier root_copier(computation->root_instruction(), /*copy_users=*/{}); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cc77339bb63220d8c9da0500ee818c7b9fb02a4b..026be75757a9129c94e2c1c3083f226790d482f4 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -87,7 +87,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { TEST_F(CopyInsertionTest, SingleConstant) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); @@ -110,9 +110,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -140,11 +140,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { // the computation result. Verify that copies are added properly. 
auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -152,7 +152,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction::CreateTuple({constant3, constant2})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -196,9 +196,8 @@ TEST_F(CopyInsertionTest, BitcastConstant) { // The output of a bitcast is its operand (same buffer), so a bitcast // constant feeding the result must have a copy added. auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 42.0}))); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1.0, 42.0}))); HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); @@ -308,9 +307,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { // copy is added. 
auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -318,7 +317,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction::CreateTuple({constant2, constant1})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); HloInstruction* gte = @@ -350,7 +349,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + HloInstruction::CreateConstant(Literal::CreateR0(10))); const Shape& loop_state_shape = nested ? nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( @@ -381,7 +380,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). 
@@ -419,7 +418,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -488,7 +487,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); @@ -503,9 +502,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); } - auto update = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); @@ -538,7 +536,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( gte0->shape(), HloOpcode::kAdd, gte0, inc)); @@ 
-548,9 +546,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // GTE(GTE(loop_state, 1), 0) -> Add auto gte10 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); - auto update10 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, gte10, update10)); @@ -574,11 +571,10 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); if (nested) { auto inner_init = builder.AddInstruction( @@ -601,9 +597,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToConstant() { auto builder = HloComputation::Builder(TestName() + ".While"); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, &builder); } @@ -620,11 +615,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + 
".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto v1 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto v2 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -632,7 +627,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -644,7 +639,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto one_vec = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto data_init = @@ -657,12 +652,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto data_init = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); - auto one_vec = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto one_vec = 
builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); @@ -677,7 +671,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const bool nested = ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(nested)); auto body = module_->AddEmbeddedComputation( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 51ecbccd494fced68d5e92eda752f5292580a190..2ca4af67cd55cfd01e952cf2306d5e475d7f4944 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -52,7 +52,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -69,9 +68,11 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", + "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:inliner", + "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", 
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep @@ -151,9 +152,12 @@ cc_library( cc_library( name = "parallel_cpu_executable", srcs = ["parallel_cpu_executable.cc"], - hdrs = ["parallel_cpu_executable.h"], + hdrs = [ + "parallel_cpu_executable.h", + ], deps = [ ":cpu_runtime", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -177,7 +181,9 @@ cc_library( cc_library( name = "ir_emitter", srcs = ["ir_emitter.cc"], - hdrs = ["ir_emitter.h"], + hdrs = [ + "ir_emitter.h", + ], deps = [ ":cpu_runtime", ":dot_op_emitter", @@ -191,7 +197,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", @@ -222,8 +227,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", @@ -283,8 +288,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:analysis", @@ -325,15 +328,20 @@ cc_library( name = "cpu_runtime", srcs = [ "cpu_runtime.cc", - "infeed_manager.cc", + "xfeed_manager.cc", ], hdrs = [ "cpu_runtime.h", - 
"infeed_manager.h", + "xfeed_manager.h", ], copts = runtime_copts(), deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", ], ) @@ -405,6 +413,7 @@ cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -413,9 +422,9 @@ cc_test( ) cc_test( - name = "infeed_manager_test", + name = "xfeed_manager_test", size = "small", - srcs = ["infeed_manager_test.cc"], + srcs = ["xfeed_manager_test.cc"], deps = [ ":cpu_runtime", "//tensorflow/core:lib", @@ -437,10 +446,16 @@ cc_library( cc_library( name = "cpu_parallelization_preparation", srcs = ["cpu_parallelization_preparation.cc"], - hdrs = ["cpu_parallelization_preparation.h"], + hdrs = [ + "cpu_parallelization_preparation.h", + ], deps = [ + ":ir_emission_utils", + ":shape_partition", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", @@ -472,7 +487,6 @@ cc_library( ":cpu_runtime", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", ], ) @@ -499,9 +513,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", 
+ "//tensorflow/core:lib", ], ) @@ -511,6 +525,7 @@ cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", @@ -518,6 +533,26 @@ cc_test( ], ) +cc_library( + name = "shape_partition", + srcs = ["shape_partition.cc"], + hdrs = ["shape_partition.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + ], +) + +cc_test( + name = "shape_partition_test", + srcs = ["shape_partition_test.cc"], + deps = [ + ":shape_partition", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 8ebf9ab110d080a017abb2077ac588672c8099bb..d86881c282488356f5c146467e5e41ecc5038511 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -35,8 +35,6 @@ limitations under the License. #include "external/llvm/include/llvm/Transforms/IPO.h" #include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h" #include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" @@ -45,7 +43,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -66,14 +63,9 @@ operator()(llvm::Module& module) const { VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); - legacy_flags::CompilerFunctorFlags* flags = - legacy_flags::GetCompilerFunctorFlags(); - string dump_path = flags->xla_debug_cpu_dump_ir; - if (!dump_path.empty()) { - std::unique_ptr f; - TF_CHECK_OK(tensorflow::Env::Default()->NewAppendableFile(dump_path, &f)); - TF_CHECK_OK(f->Append(llvm_ir::DumpModuleToString(module))); - TF_CHECK_OK(f->Close()); + + if (pre_optimization_hook_) { + TF_CHECK_OK(pre_optimization_hook_(module)); } // Build up optimization pipeline. @@ -99,6 +91,10 @@ operator()(llvm::Module& module) const { VLOG(2) << "IR after optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); + if (post_optimization_hook_) { + TF_CHECK_OK(post_optimization_hook_(module)); + } + // Generate code. 
llvm::MCContext* mc_context; llvm::legacy::PassManager codegen_passes; @@ -135,33 +131,28 @@ std::vector VectorFunctionsForTargetLibraryInfoImpl( std::vector vector_functions; const llvm::VecDesc four_wide_vector_functions[] = { - {"expf", runtime::kExpV4F32, 4}, - {"llvm.exp.f32", runtime::kExpV4F32, 4}, + {"expf", runtime::kExpV4F32SymbolName, 4}, + {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4}, - {"logf", runtime::kLogV4F32, 4}, - {"llvm.log.f32", runtime::kLogV4F32, 4}, + {"logf", runtime::kLogV4F32SymbolName, 4}, + {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4}, - {"tanhf", runtime::kTanhV4F32, 4}, - {"llvm.tanh.f32", runtime::kTanhV4F32, 4}, + {"tanhf", runtime::kTanhV4F32SymbolName, 4}, + {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4}, }; const llvm::VecDesc eight_wide_vector_functions[] = { - {"expf", runtime::kExpV8F32, 8}, - {"llvm.exp.f32", runtime::kExpV8F32, 8}, + {"expf", runtime::kExpV8F32SymbolName, 8}, + {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8}, - {"logf", runtime::kLogV8F32, 8}, - {"llvm.log.f32", runtime::kLogV8F32, 8}, + {"logf", runtime::kLogV8F32SymbolName, 8}, + {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8}, - {"tanhf", runtime::kTanhV8F32, 8}, - {"llvm.tanh.f32", runtime::kTanhV8F32, 8}, + {"tanhf", runtime::kTanhV8F32SymbolName, 8}, + {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8}, }; - // Our vectorized library calls are currently implement by calling into Eigen. - // As such, only emit calls to these routines if --xla_cpu_use_eigen is - // enabled. 
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (flags->xla_cpu_use_eigen && - (arch == llvm::Triple::x86 || llvm::Triple::x86_64)) { + if (arch == llvm::Triple::x86 || llvm::Triple::x86_64) { llvm::SmallVector features; feature_string.split(features, ',', -1, /*KeepEmpty=*/false); if (std::find(features.begin(), features.end(), "+sse4.1") != diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 17dadebe975b936b7d5d7a78ac69b890d9c8e7ac..a5358076b7f543948e0957767dfda1be43e07611 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -39,13 +39,22 @@ class CompilerFunctor { // Returns a VectorIntrinsics where all intrinsics are available. static VectorIntrinsics AllIntrinsics(); + // A callback of this type can be run before and/or after IR-level + // optimization to e.g. dump out the generated IR to disk or gather some + // statistics. + using ModuleHook = std::function; + explicit CompilerFunctor(llvm::TargetMachine* target_machine, const Disassembler* disassembler, int opt_level, - const VectorIntrinsics& available_intrinsics) + const VectorIntrinsics& available_intrinsics, + ModuleHook pre_optimization_hook = nullptr, + ModuleHook post_optimization_hook = nullptr) : target_machine_(target_machine), disassembler_(CHECK_NOTNULL(disassembler)), opt_level_(opt_level), - available_intrinsics_(available_intrinsics) {} + available_intrinsics_(available_intrinsics), + pre_optimization_hook_(pre_optimization_hook), + post_optimization_hook_(post_optimization_hook) {} // Compile a Module to an ObjectFile. 
llvm::object::OwningBinary operator()( @@ -61,6 +70,8 @@ class CompilerFunctor { const Disassembler* disassembler_; const unsigned opt_level_; const VectorIntrinsics available_intrinsics_; + ModuleHook pre_optimization_hook_; + ModuleHook post_optimization_hook_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index cdf43587b683e4e22d14d4fc08fa3705bc636de8..069979c6611e90ed2d95cbbe341198577cdf56cf 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -30,11 +29,6 @@ namespace xla { namespace cpu { StatusOr ConvCanonicalization::Run(HloModule* module) { - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (!flags->xla_cpu_use_eigen) { - return false; - } - bool changed = false; for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index f5ad431277d94039cd20cf51e0932413e87a0436..ec992f15e63b29ee67d16b6d841fedffd9c90f5b 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -59,11 +59,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. 
auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); // The kernel dimensions are in OIHW order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; @@ -113,11 +113,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in NHWC order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); // The kernel dimensions are in HWIO order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 34b99f2440b935402283d76d4a09475f4bfcb315..f115236ee71125e23341c905c01eb39fd77cb210 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "external/llvm/include/llvm/Support/TargetSelect.h" #include "external/llvm/include/llvm/Target/TargetMachine.h" #include "external/llvm/include/llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -70,10 +69,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -81,7 +82,10 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" namespace se = ::perftools::gputools; @@ -166,7 +170,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. 
- if (it.first != kXlaParallelCpuOption) { + if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { @@ -245,19 +249,23 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { +Status CpuCompiler::RunHloPasses(HloModule* module) { // Optimization pipeline. - HloPassPipeline pipeline("CPU", dump_hlo); + HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); + ReducePrecisionInsertion::AddPasses( + &pipeline, module->config().debug_options(), + HloReducePrecisionOptions::BEFORE_OP_FUSION); + // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding // where we will take this pass in future. // pipeline.AddPass(); pipeline.AddPass(); { - auto& pass = pipeline.AddPass>("simplification", - dump_hlo); + auto& pass = + pipeline.AddPass>("simplification"); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -275,6 +283,11 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + + ReducePrecisionInsertion::AddPasses( + &pipeline, module->config().debug_options(), + HloReducePrecisionOptions::AFTER_OP_FUSION); + pipeline.AddPass( module->mutable_entry_computation_layout()); // The LayoutAssignment pass may leave behind kCopy instructions which are @@ -285,8 +298,13 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { /*enable_dot_simplification=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); // Outline ops in the entry computation into calls to subcomputations. + const int max_parallelism = + module->config().intra_op_parallelism_threads() > 0 + ? 
module->config().intra_op_parallelism_threads() + : tensorflow::port::NumSchedulableCPUs(); if (CpuParallelBackendRequested(module->config())) { - pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction()); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -299,7 +317,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { if (CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. - pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction()); } pipeline.AddPass(); pipeline.AddPass(); @@ -310,6 +329,7 @@ namespace { // Align buffers to 16-byte boundaries. constexpr int64 kMemoryAlignment = 16; +auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; }; llvm::TargetOptions CompilerTargetOptions( const HloModuleConfig& module_config) { @@ -338,25 +358,83 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { } } +Status AppendIRToFile(const string& file_name, const string& ir_module_string) { + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewAppendableFile(file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(ir_module_string)); + TF_RETURN_IF_ERROR(f->Close()); + return Status::OK(); +} + +Status InitializeIRDumpHooks( + const HloModule& module, + CompilerFunctor::ModuleHook* pre_optimization_ir_dump_hook, + CompilerFunctor::ModuleHook* post_optimization_ir_dump_hook) { + const string& dump_ir_to = module.config().debug_options().xla_dump_ir_to(); + if (dump_ir_to.empty()) { + return Status::OK(); + } + + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(dump_ir_to)); + + string safe_file_name_base = module.name(); + std::replace_if(safe_file_name_base.begin(), safe_file_name_base.end(), + [](char c) { return c == '/' || c == 
'\\'; }, '_'); + + string unoptimized_ir_file_name = tensorflow::io::JoinPath( + dump_ir_to, + tensorflow::strings::StrCat("ir-", safe_file_name_base, "-no-opt.ll")); + string optimized_ir_file_name = tensorflow::io::JoinPath( + dump_ir_to, + tensorflow::strings::StrCat("ir-", safe_file_name_base, "-opt.ll")); + + // We still want to append to avoid overwriting possibly important information + // due to operator error. + + *pre_optimization_ir_dump_hook = + [unoptimized_ir_file_name](const llvm::Module& module) { + return AppendIRToFile(unoptimized_ir_file_name, + llvm_ir::DumpModuleToString(module)); + }; + + *post_optimization_ir_dump_hook = + [optimized_ir_file_name](const llvm::Module& module) { + return AppendIRToFile(optimized_ir_file_name, + llvm_ir::DumpModuleToString(module)); + }; + + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::Compile( - std::unique_ptr module, HloDumper dump_hlo, - se::StreamExecutor* stream_exec) { + std::unique_ptr module, se::StreamExecutor* stream_exec) { + VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, &InitializeLLVMCommandLineOptions, module->config()); + CompilerFunctor::ModuleHook pre_optimization_ir_dump_hook; + CompilerFunctor::ModuleHook post_optimization_ir_dump_hook; + TF_RETURN_IF_ERROR(InitializeIRDumpHooks(*module, + &pre_optimization_ir_dump_hook, + &post_optimization_ir_dump_hook)); + // Compile must be thread-safe so create a new LLVM context for the module. 
auto llvm_context = MakeUnique(); auto llvm_module = MakeUnique("__compute_module", *llvm_context); auto jit = MakeUnique(CompilerTargetOptions(module->config()), - CodeGenOptLevel(module->config())); + CodeGenOptLevel(module->config()), + pre_optimization_ir_dump_hook, + post_optimization_ir_dump_hook); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module.get())); HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; @@ -367,8 +445,17 @@ StatusOr> CpuCompiler::Compile( } std::unique_ptr cpu_executable; - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + + // Cache these flags here since we'll want to access them after the module's + // ownership is std::moved. + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (CpuParallelBackendRequested(module->config())) { + VLOG(1) << "Using parallel cpu backend"; + // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
// DependencyHloOrdering is used for the parallel emitter because the order @@ -379,12 +466,12 @@ StatusOr> CpuCompiler::Compile( std::unique_ptr assignment, BufferAssigner::Run(module.get(), MakeUnique(module.get()), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), memory_alignment)); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -400,7 +487,7 @@ StatusOr> CpuCompiler::Compile( if (instruction->opcode() == HloOpcode::kConstant) { // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. - const void* data = LiteralUtil::InternalData(instruction->literal()); + const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, MakeUnique(size)); @@ -418,11 +505,15 @@ StatusOr> CpuCompiler::Compile( } IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx); + &hlo_to_profile_idx, jit->target_machine()); + std::unique_ptr> function_names( new std::map()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; + } auto parallel_computation_iter = parallel_computations.find(embedded_computation); // All parallel computations are considered to be an entry computation for @@ -446,7 +537,7 @@ StatusOr> CpuCompiler::Compile( } string ir_module_string; - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } @@ -457,11 +548,13 @@ StatusOr> CpuCompiler::Compile( std::move(function_names), 
std::move(hlo_to_profile_idx), std::move(aligned_constants))); - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { static_cast(*cpu_executable) .set_ir_module_string(ir_module_string); } } else { + VLOG(1) << "Using sequential cpu backend"; + // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). @@ -476,12 +569,12 @@ StatusOr> CpuCompiler::Compile( BufferAssigner::Run( module.get(), MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), memory_alignment)); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } // Each computation is a single function. Emit all embedded computations @@ -489,9 +582,13 @@ StatusOr> CpuCompiler::Compile( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. 
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx); + &hlo_to_profile_idx, jit->target_machine()); + for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR( ir_emitter .EmitComputation(embedded_computation, @@ -510,7 +607,7 @@ StatusOr> CpuCompiler::Compile( string function_name = llvm_ir::AsString(entry_function->getName()); string ir_module_string; - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } @@ -520,17 +617,18 @@ StatusOr> CpuCompiler::Compile( std::move(jit), std::move(assignment), std::move(module), function_name, std::move(hlo_to_profile_idx))); - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { static_cast(*cpu_executable) .set_ir_module_string(ir_module_string); } } + VLOG(1) << "Compilation finished"; return std::move(cpu_executable); } StatusOr>> CpuCompiler::Compile( - std::vector> modules, HloDumper dump_hlos, + std::vector> modules, std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); @@ -538,7 +636,6 @@ StatusOr>> CpuCompiler::Compile( StatusOr>> CpuCompiler::CompileAheadOfTime(std::vector> modules, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { TF_RET_CHECK(!modules.empty()); std::call_once(llvm_command_line_options_initialized, @@ -627,8 +724,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::vector> results; for (size_t i = 0; i < modules.size(); ++i) { HloModule* module = modules[i].get(); + VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module)); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, @@ -640,20 +738,24 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, 
std::unique_ptr assignment, BufferAssigner::Run( module, MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), memory_alignment)); - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr); + /*hlo_to_profile_idx=*/nullptr, target_machine.get()); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR( ir_emitter .EmitComputation(embedded_computation, @@ -671,10 +773,17 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, entry_function->setName(llvm_ir::AsStringRef(entry_point_name)); + CompilerFunctor::ModuleHook pre_optimization_ir_dump_hook; + CompilerFunctor::ModuleHook post_optimization_ir_dump_hook; + TF_RETURN_IF_ERROR(InitializeIRDumpHooks(*module, + &pre_optimization_ir_dump_hook, + &post_optimization_ir_dump_hook)); + Disassembler disassembler(*target_machine); - CompilerFunctor compiler_functor(target_machine.get(), &disassembler, - opt_level, - CompilerFunctor::AllIntrinsics()); + CompilerFunctor compiler_functor( + target_machine.get(), &disassembler, opt_level, + CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook, + post_optimization_ir_dump_hook); llvm::object::OwningBinary object_file = compiler_functor(llvm_module); llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); @@ 
-704,6 +813,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::move(object_file_data), std::move(buffer_sizes), result_slice.index())); } + + VLOG(1) << "Compilation finished"; return std::move(results); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 29fa4eac61beaa25e1662b1be5afa9757ab077ea..b82e181df2b883ddac7e7d39212fb28b07ca7b0c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -110,16 +110,15 @@ class CpuCompiler : public Compiler { ~CpuCompiler() override {} StatusOr> Compile( - std::unique_ptr module, HloDumper dump_hlo, + std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> modules, HloDumper dump_hlo, + std::vector> modules, std::vector stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> modules, - HloDumper dump_hlo, const AotCompilationOptions& options) override; perftools::gputools::Platform::Id PlatformId() const override; @@ -132,7 +131,7 @@ class CpuCompiler : public Compiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* hlo_module, HloDumper dump_hlo); + Status RunHloPasses(HloModule* module); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 671d6957a39c068c416cd4fa3739f05c9ddb3baa..8787336ed0755647dd9ffbc68484d4cb9cef4790 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -66,7 +66,8 @@ CpuExecutable::CpuExecutable( CHECK(sym) << "Symbol " << entry_function_name << " not found."; // getAddress can do work under the hood in the jit, so it needs to be // guarded by the mutex. 
- compute_function_ = reinterpret_cast(sym.getAddress()); + compute_function_ = + reinterpret_cast(cantFail(sym.getAddress())); } // Given a pointer to an output buffer (following the CPU JIT calling diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index b5746769ba7e4bb2593bab7abc24f1a75a083d80..4b42530c09dbf2ff4aa767e398535e4d55cc4673 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -78,6 +78,11 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } + const Status EqualOrFail(const Executable& executable) { + // TODO(b/62952745) Implement equality test on CPU executable. + return Unimplemented("Equality test on CPU executable is not implemented."); + } + static int64 ShapeSizeBytes(const Shape& shape); private: diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index f6b1dcae75a773811f8c652dea36b7f3ca36e901..4d0e0f744ac4b02f7c4a74c5a341d6b9ce937967 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -15,19 +15,28 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { StatusOr ParallelizationPreparation::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelizationPreparation ENTRY"); + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; + TF_ASSIGN_OR_RETURN(changed, RunParallelTaskAssignment(module)); + HloComputation* entry_computation = module->entry_computation(); std::unordered_set outlined; std::vector instructions_to_outline; @@ -44,13 +53,21 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { instruction->opcode() == HloOpcode::kConstant) { continue; } + + // Outline 'instruction' in isolation if it was assigned parallel tasks. + if (OutlineParallelizableInstruction(instruction)) { + outlined.insert(instruction); + changed = true; + continue; + } + instructions_to_outline.clear(); HloInstruction* outline_candidate = instruction; instructions_to_outline.push_back(outline_candidate); bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast; // Outline sole users with the current instruction. 
- while (outline_candidate->users().size() == 1) { + while (CanOutlineWithUser(outline_candidate)) { HloInstruction* prior_candidate = outline_candidate; outline_candidate = *outline_candidate->users().begin(); all_bitcasts |= outline_candidate->opcode() == HloOpcode::kBitcast; @@ -108,6 +125,9 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } HloInstruction* root = computation->root_instruction(); // Copy root instruction if it does not define its own top-level buffer. // TODO(b/32885001) Remove these copies (at least for the unambiguous case). @@ -120,8 +140,136 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { changed = true; } } + + XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT"); + XLA_VLOG_LINES(2, module->ToString()); return changed; } +StatusOr ParallelizationPreparation::RunParallelTaskAssignment( + HloModule* module) { + VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; + bool changed = false; + // Run cost analysis on entry computation. + HloCostAnalysis cost_analysis(shape_size_); + HloComputation* computation = module->entry_computation(); + Status cost_status = computation->root_instruction()->Accept(&cost_analysis); + for (auto& instruction : computation->instructions()) { + // Currently, we do not assign parallel tasks to instructions with at least + // one of the following properties: + // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Tuple-shaped. + // TODO(b/27458679) Parallelize instructions which are skipped here. 
+ if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kCustomCall || + instruction->opcode() == HloOpcode::kSelectAndScatter || + (instruction->opcode() == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction)) || + PotentiallyImplementedAsEigenDot(*instruction) || + (instruction->opcode() == HloOpcode::kFusion && + instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || + ShapeUtil::IsTuple(instruction->shape())) { + continue; + } + + // Calculate target parallel task count in [1, max_parallelism_]. + const int64 target_parallel_task_count = GetTargetParallelTaskCount( + cost_status.ok() ? &cost_analysis : nullptr, instruction.get()); + if (target_parallel_task_count == 1) { + continue; + } + + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + VLOG(2) << "Assigning parallel task count: " << total_partition_count + << " to instruction: " << instruction->name(); + // Map 'instruction' to assigned dimension partitioning. + instruction->set_outer_dimension_partitions(dim_partition_counts); + } + + return changed; +} + +int64 ParallelizationPreparation::GetTargetParallelTaskCount( + const HloCostAnalysis* cost_analysis, HloInstruction* instruction) { + // Default to a simple cost model based on hlo size and typical L2 cache size. + // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an + // error status (likely because HLOs like CustomCall are not yet implemented + // in the HloCostAnalysis). 
+ int64 instruction_cost = shape_size_(instruction->shape()); + int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + if (cost_analysis != nullptr) { + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = 1 * cost_analysis->flop_count(*instruction) + + 2 * cost_analysis->transcendental_count(*instruction) + + 10 * cost_analysis->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism_, + std::max(1LL, instruction_cost / min_cost_per_thread)); +} + +bool ParallelizationPreparation::OutlineParallelizableInstruction( + HloInstruction* instruction) { + if (instruction->outer_dimension_partitions().empty()) { + return false; + } + // Store dimension partition counts before outlining (which clones + // 'instruction'). + std::vector dim_partition_counts = + instruction->outer_dimension_partitions(); + // Outline 'instruction' in its own sub-computation. + HloModule* module = instruction->parent()->parent(); + auto* call = module->OutlineExpressionFromComputation( + {instruction}, tensorflow::strings::StrCat("pp_", instruction->name()), + module->entry_computation()); + // Map previously assigned 'dim_partition_counts' to cloned root instruction. + VLOG(1) << "Outlining parallelizable" + << " caller: " << call->name() + << " callee: " << call->to_apply()->root_instruction()->name(); + call->to_apply()->root_instruction()->set_outer_dimension_partitions( + dim_partition_counts); + return true; +} + +bool ParallelizationPreparation::CanOutlineWithUser( + HloInstruction* instruction) { + if (instruction->users().size() != 1) { + // Do not outline 'instruction' with multiple users. 
+ return false; + } + if (AssignedParallelTasks(instruction) || + AssignedParallelTasks(*instruction->users().begin())) { + // Do not outline if 'instruction' (or user) were assigned parallel tasks. + return false; + } + return true; +} + +bool ParallelizationPreparation::AssignedParallelTasks( + HloInstruction* instruction) { + return !instruction->outer_dimension_partitions().empty() || + (instruction->opcode() == HloOpcode::kCall && + !instruction->to_apply() + ->root_instruction() + ->outer_dimension_partitions() + .empty()); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h index 62999f5686db2e4db3ace0c5580bd156edbfa994..d53fc461509cad51778dba37922212731236952f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,14 +33,51 @@ namespace cpu { // handle While constructs. class ParallelizationPreparation : public HloPassInterface { public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. 
+ ParallelizationPreparation( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_(shape_size) {} ~ParallelizationPreparation() override {} + tensorflow::StringPiece name() const override { return "cpu-parallel-prepare"; } - // Run instruction fusion on the given computation. Returns whether the + // Run parallel preparation on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // Assigns parallel task partitions to conformant instructions in 'module'. + // Returns true on success or error status otherwise. + StatusOr RunParallelTaskAssignment(HloModule* module); + + // Returns the target parallel task count for 'instruction'. + // Utilizes 'cost_analysis' if non-null. + // Otherwise defaults to a simple HLO output size-based cost model. + int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis, + HloInstruction* instruction); + + // Outlines 'instruction' from entry computation, if it had + // been assigned parallel tasks in an earlier pass through the computation. + // Returns true if 'instruction' was successfully outlined, false otherwise. + bool OutlineParallelizableInstruction(HloInstruction* instruction); + + // Returns true if 'instruction' can be outlined into the same sub-computation + // with its single user (parallelizable instructions are not outlined with + // each other). Returns false otherwise. + bool CanOutlineWithUser(HloInstruction* instruction); + + // Returns true if 'instruction' (or the root of the sub-computation that + // 'instruction' calls) has had parallel tasks assigned in earlier pass. + // Returns false otherwise. 
+ bool AssignedParallelTasks(HloInstruction* instruction); + + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 253de20f25127bf0ac23d5969e0f16c143396e47..5d6efa535958a2757a22f633aa41d08ca712cb5d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -17,35 +17,118 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace xla { namespace cpu { namespace runtime { -InfeedManager* GetInfeedManager() { - static InfeedManager* manager = new InfeedManager; +XfeedManager* GetXfeedManager() { + static XfeedManager* manager = new XfeedManager; return manager; } +extern const char* const kEigenMatMulF32SymbolName = + "__xla_cpu_runtime_EigenMatMulF32"; +extern const char* const kEigenMatMulF64SymbolName = + "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kEigenConvF32SymbolName = + "__xla_cpu_runtime_EigenConvF32"; +extern const char* const kEigenSingleThreadedMatMulF32SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; +extern const char* const kEigenSingleThreadedMatMulF64SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF64"; +extern const char* const kEigenSingleThreadedConvF32SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedConvF32"; +extern const char* const kAcquireInfeedBufferForDequeueSymbolName = + "__xla_cpu_runtime_AcquireInfeedBufferForDequeue"; +extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName = + "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue"; +extern const char* const kAcquireOutfeedBufferForPopulationSymbolName = + 
"__xla_cpu_runtime_AcquireOutfeedBufferForPopulation"; +extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = + "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; +extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu } // namespace xla -void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - xla::int32 buffer_length) { - xla::cpu::runtime::InfeedManager* infeed = - xla::cpu::runtime::GetInfeedManager(); +namespace { + +tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { + xla::StatusOr shape = + xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); + if (shape.ok()) { + return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie()); + } + return ""; +} + +} // namespace + +void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, + const void* shape, + xla::int32 shape_length) { + if (VLOG_IS_ON(2)) { + LOG(INFO) << "AcquireInfeedBufferForDequeue: " + << ShapeString(shape, shape_length); + } + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); + // Wait until there's a buffer to dequeue. 
+ xla::cpu::runtime::XfeedBuffer* buffer = + xfeed->infeed()->BlockingDequeueBuffer(); + CHECK_EQ(buffer->length(), buffer_length) + << "XLA program infeed request buffer size " << buffer_length + << " did not match the runtime's infed buffer length " << buffer->length() + << "; program reports desired shape: " + << ShapeString(shape, shape_length); + return buffer->data(); +} + +void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, + xla::int32 shape_length) { + if (VLOG_IS_ON(2)) { + LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " + << ShapeString(shape_ptr, shape_length); + } + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); + xla::StatusOr shape = + xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); + xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, + std::move(shape)); +} + +void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) { + if (VLOG_IS_ON(2)) { + LOG(INFO) << "AcquireOutfeedBufferForPopulation: " + << ShapeString(shape_ptr, shape_length); + } + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); // Wait until there's a buffer to dequeue. 
- xla::cpu::runtime::InfeedBuffer* buffer = infeed->BlockingDequeueBuffer(); - CHECK_EQ(buffer->length(), buffer_length); + xla::cpu::runtime::XfeedBuffer* buffer = + xfeed->outfeed()->BlockingDequeueBuffer(); + CHECK_EQ(buffer->length(), buffer_length) + << "XLA program outfeed request buffer size " << buffer_length + << " did not match the runtime's outfeed buffer length " + << buffer->length() << "; program reports outfed shape: " + << ShapeString(shape_ptr, shape_length); return buffer->data(); } -void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, - void* buffer_ptr) { - xla::cpu::runtime::InfeedManager* infeed = - xla::cpu::runtime::GetInfeedManager(); - infeed->ReleaseCurrentBuffer(buffer_length, buffer_ptr); +void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, + xla::int32 shape_length) { + if (VLOG_IS_ON(2)) { + LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " + << ShapeString(shape_ptr, shape_length); + } + xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); + xla::StatusOr shape = + xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); + xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, shape); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 8eae2102305a3898c244a356d383184139e9208e..29feb7267fe97f6876827b6cbfa6217a0cecf238 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -26,7 +26,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -41,22 +41,23 @@ namespace runtime { // the actual symbol. // 2. When using ahead-of-time compilation, the linker can resolve the name // because it is a symbol in the cpu_runtime library. -constexpr char kEigenMatmulF32SymbolName[] = "__xla_cpu_runtime_EigenMatMulF32"; -constexpr char kEigenMatmulF64SymbolName[] = "__xla_cpu_runtime_EigenMatMulF64"; -constexpr char kEigenConvF32SymbolName[] = "__xla_cpu_runtime_EigenConvF32"; -constexpr char kEigenSingleThreadedMatmulF32SymbolName[] = - "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; -constexpr char kEigenSingleThreadedMatmulF64SymbolName[] = - "__xla_cpu_runtime_EigenSingleThreadedMatMulF64"; -constexpr char kEigenSingleThreadedConvF32SymbolName[] = - "__xla_cpu_runtime_EigenSingleThreadedConvF32"; -constexpr char kAcquireInfeedBufferForDequeueSymbolName[] = - "__xla_cpu_runtime_AcquireInfeedBufferForDequeue"; -constexpr char kReleaseInfeedBufferAfterDequeueSymbolName[] = - "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue"; +extern const char* const kEigenMatMulF32SymbolName; +extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kEigenConvF32SymbolName; +extern const char* const kEigenSingleThreadedMatMulF32SymbolName; +extern const char* const kEigenSingleThreadedMatMulF64SymbolName; +extern const char* const kEigenSingleThreadedConvF32SymbolName; +extern const char* const kAcquireInfeedBufferForDequeueSymbolName; +extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; +extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; +extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; + +// All symbol names for XLA CPU runtime functions 
need to start with this +// prefix. +extern const char* const kXlaCpuRuntimeSymbolNamePrefix; // Returns the infeed manager used by the CPU runtime. -InfeedManager* GetInfeedManager(); +XfeedManager* GetXfeedManager(); } // namespace runtime } // namespace cpu @@ -64,13 +65,19 @@ InfeedManager* GetInfeedManager(); extern "C" { +// Note: in the runtime entry points below, the shape pointer and shape_length +// reflect values that can be deserialized via +// llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified +// type information from the generated program to the runtime, which helps check +// the type safety and contract for the emitted-code/runtime communication. + // Blocks until the next infeed buffer is ready to be dequeued, then // returns it. Fails catastrophically if the next enqueued buffer is // not of the correct length in bytes. Checking the shape rather than // the length would be more exact, but the length check is chosen as a // tradeoff between error checking and speed/simplicity. extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - xla::int32 buffer_length); + xla::int32 buffer_length, const void* shape, xla::int32 shape_length); // Relinquishes the next infeed buffer that was returned by // __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call @@ -85,7 +92,27 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( // implemented we will add support for multiple outstanding buffers // that can be returned out of order. extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr); -} + xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, + xla::int32 shape_length); + +// Blocks until the next outfeed buffer is available to be populated, then +// returns it. 
+extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length); + +// Relinquishes the outfeed buffer after it has been populated. +// buffer_ptr must have been previously returned by +// __xla_cpu_runtime_AcquireOutfeedBufferForPopulation. +// Once this call completes, buffer_ptr may no longer be accessed. +// buffer_length must match the length passed to the call to +// __xla_cpu_runtime_AcquireInfeedBufferForDequeue that returned +// buffer_ptr. This function must be called before the next buffer is +// acquired, i.e., there may only be one outstanding outfeed buffer in +// use by the runtime. +extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, + xla::int32 shape_length); + +} // extern "C" #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc index 646254887c83fcaff8fd5def9fafc8ff17d03d32..f6664bb854e2dda4c199d6f716e6dc7173447cea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc @@ -19,17 +19,30 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" +#ifdef __AVX__ +xla::cpu::runtime::V8F32 __xla_cpu_runtime_ExpV8F32( + xla::cpu::runtime::V8F32 x) { + return Eigen::internal::pexp(x); +} + +xla::cpu::runtime::V8F32 __xla_cpu_runtime_LogV8F32( + xla::cpu::runtime::V8F32 x) { + return Eigen::internal::plog(x); +} + +xla::cpu::runtime::V8F32 __xla_cpu_runtime_TanhV8F32( + xla::cpu::runtime::V8F32 x) { + return Eigen::internal::ptanh(x); +} +#endif // __AVX__ + namespace xla { namespace cpu { namespace runtime { -#ifdef __AVX__ -V8F32 ExpV8F32(V8F32 x) { return Eigen::internal::pexp(x); } - -V8F32 LogV8F32(V8F32 x) { return Eigen::internal::plog(x); } - -V8F32 TanhV8F32(V8F32 x) { return Eigen::internal::ptanh(x); } -#endif // __AVX__ +const char *const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; +const char *const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32"; +const char *const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h index 89721aaf835eec5e4a8be0fbabb310b084065825..c15710fb00197d41c1047d3e8ade0165f18cf0fb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h @@ -28,23 +28,28 @@ namespace xla { namespace cpu { namespace runtime { -constexpr char kExpV8F32[] = "__xla_cpu_runtime_ExpV8F32"; -constexpr char kLogV8F32[] = "__xla_cpu_runtime_LogV8F32"; -constexpr char kTanhV8F32[] = "__xla_cpu_runtime_TanhV8F32"; +extern const char *const kExpV8F32SymbolName; +extern const char *const kLogV8F32SymbolName; +extern const char *const kTanhV8F32SymbolName; typedef float V8F32 __attribute__((__vector_size__(32))); +} // namespace runtime +} // namespace cpu +} // namespace xla + +extern "C" { // The following functions are vectorized versions of a selection of libm // library functions. 
// References to these functions are created by the LLVM vectorizer. -V8F32 ExpV8F32(V8F32 x) TF_ATTRIBUTE_WEAK; +xla::cpu::runtime::V8F32 __xla_cpu_runtime_ExpV8F32(xla::cpu::runtime::V8F32 x) + TF_ATTRIBUTE_WEAK; -V8F32 LogV8F32(V8F32 x) TF_ATTRIBUTE_WEAK; +xla::cpu::runtime::V8F32 __xla_cpu_runtime_LogV8F32(xla::cpu::runtime::V8F32 x) + TF_ATTRIBUTE_WEAK; -V8F32 TanhV8F32(V8F32 x) TF_ATTRIBUTE_WEAK; - -} // namespace runtime -} // namespace cpu -} // namespace xla +xla::cpu::runtime::V8F32 __xla_cpu_runtime_TanhV8F32(xla::cpu::runtime::V8F32 x) + TF_ATTRIBUTE_WEAK; +} #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc index 69d04427c60b0d8db8a8859b4abff9bfa7e93260..58ec9fc6e8ee7329c5dc1624cca2f0f0f4b68f59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc @@ -19,29 +19,36 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" -namespace xla { -namespace cpu { -namespace runtime { - #ifdef __SSE4_1__ -V4F32 ExpV4F32(V4F32 x) { +xla::cpu::runtime::V4F32 __xla_cpu_runtime_ExpV4F32( + xla::cpu::runtime::V4F32 x) { Eigen::internal::Packet4f p = x; return Eigen::internal::pexp(p); } -V4F32 LogV4F32(V4F32 x) { +xla::cpu::runtime::V4F32 __xla_cpu_runtime_LogV4F32( + xla::cpu::runtime::V4F32 x) { Eigen::internal::Packet4f p = x; return Eigen::internal::plog(p); } -V4F32 TanhV4F32(V4F32 x) { +xla::cpu::runtime::V4F32 __xla_cpu_runtime_TanhV4F32( + xla::cpu::runtime::V4F32 x) { Eigen::internal::Packet4f p = x; return Eigen::internal::ptanh(p); } #endif // __SSE4_1__ +namespace xla { +namespace cpu { +namespace runtime { + +const char *const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; +const char *const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32"; +const char *const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; + } // namespace runtime } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h index ded206f90a076ba81643799c07e3f3a7d481eaf2..7ab9a52d00848891b73415fdb5cb49c515243c05 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h @@ -28,23 +28,29 @@ namespace xla { namespace cpu { namespace runtime { -constexpr char kExpV4F32[] = "__xla_cpu_runtime_ExpV4F32"; -constexpr char kLogV4F32[] = "__xla_cpu_runtime_LogV4F32"; -constexpr char kTanhV4F32[] = "__xla_cpu_runtime_TanhV4F32"; +extern const char *const kExpV4F32SymbolName; +extern const char *const kLogV4F32SymbolName; +extern const char *const kTanhV4F32SymbolName; typedef float V4F32 __attribute__((__vector_size__(16))); +} // namespace runtime +} // namespace cpu +} // namespace xla + +extern "C" { + // The following functions are vectorized versions of a selection of libm // library 
functions. // References to these functions are created by the LLVM vectorizer. -V4F32 ExpV4F32(V4F32 x) TF_ATTRIBUTE_WEAK; +xla::cpu::runtime::V4F32 __xla_cpu_runtime_ExpV4F32(xla::cpu::runtime::V4F32 x) + TF_ATTRIBUTE_WEAK; -V4F32 LogV4F32(V4F32 x) TF_ATTRIBUTE_WEAK; +xla::cpu::runtime::V4F32 __xla_cpu_runtime_LogV4F32(xla::cpu::runtime::V4F32 x) + TF_ATTRIBUTE_WEAK; -V4F32 TanhV4F32(V4F32 x) TF_ATTRIBUTE_WEAK; - -} // namespace runtime -} // namespace cpu -} // namespace xla +xla::cpu::runtime::V4F32 __xla_cpu_runtime_TanhV4F32(xla::cpu::runtime::V4F32 x) + TF_ATTRIBUTE_WEAK; +} #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 420f9cebc5b1ded365c20079589ebc79a03b3164..f45c28ef74c7ef716e7f0330a1c10abc528a90ee 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -22,9 +22,9 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/Instructions.h" #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/IR/Value.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +44,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config) : dot_(dot), transpose_lhs_(transpose_lhs), transpose_rhs_(transpose_rhs), @@ -52,18 +53,20 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, lhs_array_(lhs_array), rhs_array_(rhs_array), executable_run_options_value_(executable_run_options_value), - ir_builder_(ir_builder) {} + ir_builder_(ir_builder), + hlo_module_config_(hlo_module_config) {} /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder) { + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F32 == type || F64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, lhs_array, rhs_array, executable_run_options_value, - ir_builder); + ir_builder, hlo_module_config); return 
dot_emitter.Emit(); } @@ -233,22 +236,22 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - bool multi_threaded = flags->xla_cpu_multi_thread_eigen; + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; const char* fn_name; switch (type) { case F32: - fn_name = multi_threaded - ? runtime::kEigenMatmulF32SymbolName - : runtime::kEigenSingleThreadedMatmulF32SymbolName; + fn_name = multi_threaded_eigen + ? runtime::kEigenMatMulF32SymbolName + : runtime::kEigenSingleThreadedMatMulF32SymbolName; float_type = ir_builder_->getFloatTy(); break; case F64: - fn_name = multi_threaded - ? runtime::kEigenMatmulF64SymbolName - : runtime::kEigenSingleThreadedMatmulF64SymbolName; + fn_name = multi_threaded_eigen + ? runtime::kEigenMatMulF64SymbolName + : runtime::kEigenSingleThreadedMatMulF64SymbolName; float_type = ir_builder_->getDoubleTy(); break; default: diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 44dfe5f2a91222d99907e31062fb1d8f74aed3ff..b6147163802dde12a8bf7dde91ac8dad45ba1990 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" @@ -39,7 +40,8 @@ class DotOpEmitter { const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder); + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config); private: DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, @@ -47,7 +49,8 @@ class DotOpEmitter { const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config); // Emits the IR to perform the dot operation. tensorflow::Status Emit(); @@ -82,6 +85,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* ir_builder_; + const HloModuleConfig& hlo_module_config_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc deleted file mode 100644 index c65d8216606a1caa561adea5a83c8f1aa2c82906..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" - -#include - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class InfeedManagerTest : public ::testing::Test {}; - -class TestInfeedBuffer : public cpu::runtime::InfeedBuffer { - public: - explicit TestInfeedBuffer(int32 length) - : done_called_(false), length_(length) {} - ~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); } - - int32 length() override { return length_; } - void* data() override { return nullptr; } - void Done() override { - CHECK(!done_called_); - done_called_ = true; - } - - private: - bool done_called_; - int32 length_; -}; - -void ProcessNextBuffer(int32 length) { - void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(length); - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer); -} - -TEST_F(InfeedManagerTest, SingleThreadedSequential) { - TestInfeedBuffer* a = new TestInfeedBuffer(64); - TestInfeedBuffer* b = new TestInfeedBuffer(32); - - cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); - - infeed->EnqueueBuffer(a); - infeed->EnqueueBuffer(b); - ProcessNextBuffer(a->length()); - ProcessNextBuffer(b->length()); -} - -TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { - TestInfeedBuffer* a = new 
TestInfeedBuffer(64); - TestInfeedBuffer* b = new TestInfeedBuffer(32); - - cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); - - infeed->EnqueueBuffer(a); - ProcessNextBuffer(a->length()); - infeed->EnqueueBuffer(b); - ProcessNextBuffer(b->length()); -} - -TEST_F(InfeedManagerTest, MultiThreaded) { - tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); - - cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); - - const int32 length = 64; - - pool.Schedule([infeed]() { - // Spin for 100 milliseconds - int64 start_micros = tensorflow::Env::Default()->NowMicros(); - while (true) { - int64 end_micros = tensorflow::Env::Default()->NowMicros(); - if ((end_micros - start_micros) >= 100000) { // 100 ms - break; - } - } - TestInfeedBuffer* a = new TestInfeedBuffer(length); - infeed->EnqueueBuffer(a); - }); - - ProcessNextBuffer(length); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 2d855d0eb1e9448707b3916d20803cebf2ebabe4..859329e2c1ddca9dbea14c16b67f63d4803b6acd 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -26,11 +25,6 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution) { - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (!flags->xla_cpu_use_eigen) { - return false; - } - // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -82,11 +76,6 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, } // namespace bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (!flags->xla_cpu_use_eigen) { - return false; - } - // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 7ad497ff1a27ff083517de6a82a8c4b903800cce..3ee417191d3def7e3f0e44155c6c308378c30b96 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -32,8 +32,9 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/Instructions.h" #include "external/llvm/include/llvm/IR/Intrinsics.h" #include "external/llvm/include/llvm/IR/LLVMContext.h" +#include "external/llvm/include/llvm/Target/TargetRegisterInfo.h" +#include "external/llvm/include/llvm/Target/TargetSubtargetInfo.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -52,9 +53,20 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace { +const char* kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; +bool VectorizedReduceDisabled(const xla::HloModuleConfig& config) { + return config.debug_options().xla_backend_extra_options().count( + kXlaDisableVectorizedReduce); +} +} // namespace namespace xla { @@ -65,14 +77,16 @@ namespace cpu { IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, - const std::unordered_map* hlo_to_profile_idx) + const std::unordered_map* hlo_to_profile_idx, + llvm::TargetMachine* target_machine) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), ir_builder_(llvm_module->getContext()), hlo_to_profile_idx_(hlo_to_profile_idx), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), - hlo_module_config_(hlo_module.config()) { + hlo_module_config_(hlo_module.config()), + 
target_machine_features_(target_machine) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -83,7 +97,14 @@ StatusOr IrEmitter::EmitComputation( bool is_entry_computation, std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix + << "]; ordered? " << (instruction_order != nullptr); + num_dynamic_loop_bounds_ = 0; + if (!computation->root_instruction()->outer_dimension_partitions().empty()) { + num_dynamic_loop_bounds_ = + computation->root_instruction()->outer_dimension_partitions().size(); + } + InitializeIrFunction(function_name, is_entry_computation); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -91,11 +112,10 @@ StatusOr IrEmitter::EmitComputation( arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(is_entry_computation, use_rdtscp, GetProfileCountersArgument()); - if (instruction_order != nullptr) { - TF_RETURN_IF_ERROR(computation->root_instruction()->AcceptOrdered( - this, *instruction_order)); + if (instruction_order == nullptr) { + TF_RETURN_IF_ERROR(computation->Accept(this)); } else { - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } InsertOrDie(&emitted_functions_, computation, compute_function_); @@ -112,7 +132,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name, bool is_entry_computation) { // The function signature is: // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* prof_counters) + // i64* dynamic_loop_bounds, i64* prof_counters) // // retval: points to the returned value. 
// params: address of an array with pointers to parameters. @@ -152,6 +172,10 @@ void IrEmitter::InitializeIrFunction(const string& function_name, // | temp 0 | | temp 1 | | temp N-1 | // \---------/ \---------/ \-----------/ // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // // /---------------------------------------------\ // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | // (elided for aot) \---------------------------------------------/ @@ -164,6 +188,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name, llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); std::vector compute_function_params( {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } if (hlo_to_profile_idx_) { compute_function_params.push_back(i64_ptr_type); } @@ -190,6 +217,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name, (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + } if (hlo_to_profile_idx_) { (++arg_iter)->setName("prof_counters"); } @@ -242,12 +272,12 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, return Status::OK(); } -Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) { +Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. 
TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); emitted_value_[copy] = copy_value; - return EmitMemcpy(*operand, *copy); + return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. return DefaultAction(copy); @@ -358,63 +388,158 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, Status IrEmitter::HandleInfeed(HloInstruction* infeed) { VLOG(2) << "HandleInfeed: " << infeed->ToString(); + const Shape& shape = infeed->shape(); + + // The infeed operation produces data (dequeued from the infeed queue) at this + // address, which has been provided by buffer assignment. + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(infeed)); + + if (ShapeUtil::IsTuple(shape)) { + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); + + // For a tuple, we first copy each of the internal elements to + // their corresponding target locations. We then construct the + // tuple outer buffer containing pointers to the internal + // elements. + std::vector tuple_element_addresses; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, + assignment_.GetUniqueSlice(infeed, {i})); + + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(shape, i); + + // Only the outer tuple buffer's target address is obtained from + // EmitTargetAddressForOp to handle the case when Infeed is the + // root instruction. Target addresses for internal elements can + // be obtained from EmitTempBufferPointer. 
+ llvm::Value* tuple_element_address = + EmitTempBufferPointer(buffer, tuple_element_shape); + + TF_RETURN_IF_ERROR(EmitXfeedTransfer( + XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); + + tuple_element_addresses.push_back(tuple_element_address); + } + + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), + tuple_element_addresses, &ir_builder_); + } else { + TF_RETURN_IF_ERROR( + EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address)); + } + + emitted_value_[infeed] = target_address; + + return Status::OK(); +} + +Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, + llvm::Value* program_buffer_address) { + int64 length = ByteSizeOf(shape); + if (length <= 0 || length > std::numeric_limits::max()) { + return InvalidArgument( + "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "size range", + length); + } + int32 length_32 = static_cast(length); + + int32 shape_length; + TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr, + llvm_ir::EncodeSelfDescribingShapeConstant( + shape, &shape_length, &ir_builder_)); + // The signature of the acquire infeed buffer function is: // // (void*)(int32 length); - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::Type* int32_type = ir_builder_.getInt32Ty(); - llvm::FunctionType* acquire_type = - llvm::FunctionType::get(i8_ptr_type, {int32_type}, - /*isVarArg=*/false); + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::FunctionType* acquire_type = llvm::FunctionType::get( + i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, + /*isVarArg=*/false); - llvm::Function* acquire_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + llvm::Function* acquire_func; + if (kind == XfeedKind::kInfeed) { + acquire_func = llvm::cast(module_->getOrInsertFunction( + runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + } else { + 
acquire_func = llvm::cast(module_->getOrInsertFunction( + runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type)); + } acquire_func->setCallingConv(llvm::CallingConv::C); // The signature of the release infeed buffer function is: // // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), {int32_type, i8_ptr_type}, + ir_builder_.getVoidTy(), + {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, /*isVarArg=*/false); - llvm::Function* release_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + llvm::Function* release_func; + if (kind == XfeedKind::kInfeed) { + release_func = llvm::cast(module_->getOrInsertFunction( + runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + } else { + release_func = llvm::cast(module_->getOrInsertFunction( + runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type)); + } release_func->setCallingConv(llvm::CallingConv::C); - const Shape& shape = infeed->shape(); - int64 length = ByteSizeOf(shape); - if (length > std::numeric_limits::max()) { - return InvalidArgument("infeed buffer length %lld is too large", length); + // Implementation note: this call informs the runtime that it wants a buffer + // of size exactly 'length_32', and the runtime is responsible for + // check-failing the process if there is a mismatch, versus passing us back a + // buffer that we might overrun. + llvm::Value* acquired_pointer = ir_builder_.CreateCall( + acquire_func, {ir_builder_.getInt32(length_32), shape_ptr, + ir_builder_.getInt32(shape_length)}); + + if (kind == XfeedKind::kInfeed) { + // Copy to the program buffer address from the acquired buffer. + ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, + length_32, 1); + } else { + // Outfeed -- copy from the in-program address to the acquired buffer. 
+ ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, + length_32, 1); } - int32 length_32 = static_cast(length); - - llvm::Value* acquired_pointer = - ir_builder_.CreateCall(acquire_func, {ir_builder_.getInt32(length_32)}); - - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(infeed)); - - ir_builder_.CreateMemCpy(target_address, acquired_pointer, length_32, 1); ir_builder_.CreateCall(release_func, - {ir_builder_.getInt32(length_32), acquired_pointer}); - - emitted_value_[infeed] = target_address; + {ir_builder_.getInt32(length_32), acquired_pointer, + shape_ptr, ir_builder_.getInt32(shape_length)}); return Status::OK(); } Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { - // TODO(b/34359662): Implement outfeed on CPU. - return Unimplemented("Outfeed is not supported on CPU (b/34359662)."); + HloInstruction* operand = outfeed->operands()[0]; + const Shape& operand_shape = operand->shape(); + + llvm::Value* value = GetEmittedValueFor(operand); + if (!ShapeUtil::IsTuple(operand_shape)) { + return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); + } + + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape)); + + for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(operand_shape, i); + llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( + tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), + value, &ir_builder_); + TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, + tuple_element_shape, tuple_element)); + } + + return Status::OK(); } Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { // TODO(b/26783907): Implement sort on CPU. 
- return Unimplemented("Sort is not supported on GPU (b/26783907)."); + return Unimplemented("Sort is not supported on CPU (b/26783907)."); } Status IrEmitter::HandleTuple( @@ -760,7 +885,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, // Dot operation is complicated so we delegate to a helper class. TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_)); + lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, + hlo_module_config_)); emitted_value_[dot] = target_address; return Status::OK(); @@ -845,9 +971,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); const char* fn_name = - (flags->xla_cpu_multi_thread_eigen + (multi_threaded_eigen ? runtime::kEigenConvF32SymbolName : runtime::kEigenSingleThreadedConvF32SymbolName); llvm::Function* conv_func = llvm::cast( @@ -1039,6 +1166,237 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { "Cross replica sum not implemented on CPU. See b/33011107."); } +// Fills up the free variables in 'index_with_free_var' with values from +// 'filler_index'. The size of free variables must be the same as the +// size of 'filler_index'. +// +// This is often used after dimension reduction, where +// 'index_with_free_var' has one or more dimensions reduced, which serves as +// free variables (represented as nullptr). For example, if we have a 4 +// dimensional input and index for the dimension being reduced is +// 2 (third dimension), we will have an index like [i, j, NULL, k] +// after reduced dimension. 
+// +// Here we fill up that free variable by 'filler_index', which contains +// the value in the reduced dimension. +static llvm_ir::IrArray::Index FillReducedDimensionIndex( + llvm_ir::IrArray::Index index_with_free_var, + llvm_ir::IrArray::Index filler_index) { + llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); + + for (size_t i = 0; i < index_with_free_var.size(); ++i) { + if (index_with_free_var[i] == nullptr) { + index_with_free_var[i] = *it++; + } + } + CHECK(filler_index.end() == it); + return index_with_free_var; +} + +Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { + // The output of BatchNormTraining is a tuple of three element: + // - An N-dimensional array containing normalized values. + // - A 1 dimensional array containing the mean value for each feature. + // - A 1 dimensional array containing the variance value for each feature. + HloInstruction* operand = batch_norm_training->operands()[0]; + HloInstruction* scale = batch_norm_training->operands()[1]; + HloInstruction* offset = batch_norm_training->operands()[2]; + float epsilon = batch_norm_training->epsilon(); + int64 feature_index = batch_norm_training->feature_index(); + TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) && + ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3); + + const Shape& output_shape = + ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0); + const Shape& feature_shape = + ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1); + + // Reduce vector of the non-feature dimensions. + std::vector dimensions_to_reduce; + + for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) { + if (i != feature_index) { + dimensions_to_reduce.push_back(i); + } + } + + // Get the second and third allocations in the output tuple, which should be + // used to store the result of mean and variance value calculation. 
+ TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice_mean, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1})); + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice_var, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2})); + const int feature_count = output_shape.dimensions(feature_index); + const int size_in_elements = ShapeUtil::ElementsIn(output_shape); + TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements); + const int elements_per_feature = size_in_elements / feature_count; + + llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape); + llvm_ir::IrArray mean_array(mean, feature_shape); + + llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape); + llvm_ir::IrArray var_array(var, feature_shape); + + // This loop calculates mean and variance for each feature. + // + // In theory this could be swapped by multi-output fusion. We will evaluate + // this when it's ready. + // + // For variance calculation, we use a simplified formula so we can fuse the + // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2. + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter( + [this, operand, dimensions_to_reduce, feature_shape, var_array, + elements_per_feature](const llvm_ir::IrArray::Index& index) { + PrimitiveType element_type = operand->shape().element_type(); + // Used to calculate E(X). + llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + "sum_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(element_type)); + + // Used to calculate E(X^2). 
+ llvm::Value* sum_square_address = + llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + "sum_square_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(element_type)); + + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), + sum_address); + + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), + sum_square_address); + + llvm_ir::ForLoopNest loops(&ir_builder_); + + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions( + operand->shape(), dimensions_to_reduce, "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray::Index input_index = + FillReducedDimensionIndex(reduced_dims_index, index); + llvm::Value* new_value = + operand_array.EmitReadArrayElement(input_index, &ir_builder_); + + llvm::Value* new_value_square = + ir_builder_.CreateFMul(new_value, new_value); + + llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address); + llvm::Value* current_sum_square = + ir_builder_.CreateLoad(sum_square_address); + // Update sum. + ir_builder_.CreateStore( + ir_builder_.CreateFAdd(current_sum, new_value), sum_address); + + // Update sum square. + ir_builder_.CreateStore( + ir_builder_.CreateFAdd(current_sum_square, new_value_square), + sum_square_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), + &ir_builder_); + + llvm::Value* sum = ir_builder_.CreateLoad(sum_address); + llvm::Value* elements_per_feature_value = llvm::ConstantFP::get( + ir_builder_.getFloatTy(), elements_per_feature); + llvm::Value* mean = + ir_builder_.CreateFDiv(sum, elements_per_feature_value); + llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean); + llvm::Value* sum_square = + ir_builder_.CreateLoad(sum_square_address); + + // Var=E(X^2) - E(X)^2. 
+ llvm::Value* var = ir_builder_.CreateFSub( + ir_builder_.CreateFDiv(sum_square, elements_per_feature_value), + mean_square); + + var_array.EmitWriteArrayElement(index, var, &ir_builder_); + return mean; + }, + mean_array, &ir_builder_) + .EmitLoop()); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(batch_norm_training)); + + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); + + llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape); + + llvm_ir::IrArray target_array(normalized, output_shape); + + AddAliasingInformationToIrArray(*batch_norm_training, &target_array); + + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter( + [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce, + feature_index, offset, scale](const llvm_ir::IrArray::Index& index) { + // The following logic normalizes the input value, scales and shifts + // it: + // + // normalized = (input - mean) / sqrt(variance + epsilon) + // result = normalized * scale + offset + + // Current index in the feature dimension. 
+ llvm_ir::IrArray::Index feature_index_value(1, + index[feature_index]); + + llvm::Value* mean = mean_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm::Value* var = var_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm::Value* input = + operand_array.EmitReadArrayElement(index, &ir_builder_); + + llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd( + var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon)); + llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()}); + llvm::Value* variance_sqrt = + ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); + llvm::Value* normalized = ir_builder_.CreateFDiv( + ir_builder_.CreateFSub(input, mean), variance_sqrt); + llvm_ir::IrArray offset_array(GetIrArrayForOp(offset)); + llvm::Value* offset = offset_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm_ir::IrArray scale_array(GetIrArrayForOp(scale)); + llvm::Value* scale = scale_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm::Value* result = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(normalized, scale), offset); + + return result; + }, + target_array, &ir_builder_) + .EmitLoop()); + + llvm_ir::EmitTuple( + llvm_ir::IrArray(target_address, batch_norm_training->shape()), + {normalized, mean, var}, &ir_builder_); + emitted_value_[batch_norm_training] = target_address; + + return Status::OK(); +} + +Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + // TODO(b/62843645) Implement BatchNormGrad on CPU backend. + return Unimplemented( + "BatchNormGrad is not implemented on CPU. 
See b/62843645."); +} + Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); auto param_number = parameter->parameter_number(); @@ -1073,10 +1431,450 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } +IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( + HloComputation* function, string* failure_reason) const { + CHECK_EQ(function->num_parameters(), 2); + + auto root_instruction = function->root_instruction(); + CHECK(ShapeUtil::IsScalar(root_instruction->shape())); + + if (root_instruction->operand_count() != 2) { + *failure_reason = "root instruction is not a binary operation"; + return nullptr; + } + + const Shape& root_shape = root_instruction->shape(); + bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape); + bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape); + bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape); + + auto lhs = root_instruction->operand(0); + auto rhs = root_instruction->operand(1); + + auto param_0 = function->parameter_instruction(0); + auto param_1 = function->parameter_instruction(1); + if (!(lhs == param_0 && rhs == param_1) && + !(rhs == param_0 && lhs == param_1)) { + *failure_reason = + "root instruction is not a binary operation on the incoming arguments"; + return nullptr; + } + + CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape())); + + // This is visually similar to ElementalIrEmitter, though conceptually we're + // doing something different here. ElementalIrEmitter emits scalar operations + // while these emit scalar or vector operations depending on the type of the + // operands. 
+ switch (root_instruction->opcode()) { + default: + *failure_reason = "did not recognize root instruction opcode"; + return nullptr; + + case HloOpcode::kAdd: + return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { + return root_is_integral ? ir_builder->CreateAdd(lhs, rhs) + : ir_builder->CreateFAdd(lhs, rhs); + }; + + case HloOpcode::kMultiply: + return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { + return root_is_integral ? ir_builder->CreateMul(lhs, rhs) + : ir_builder->CreateFMul(lhs, rhs); + }; + + case HloOpcode::kLogicalAnd: + return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; + + case HloOpcode::kLogicalOr: + return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; + + case HloOpcode::kMaximum: + return [root_is_floating_point, root_is_signed]( + llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { + if (root_is_floating_point) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, + {lhs, rhs}, {lhs->getType()}, + ir_builder); + } + + return ir_builder->CreateSelect( + ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs, rhs), + lhs, rhs); + }; + + case HloOpcode::kMinimum: + return [root_is_floating_point, root_is_signed]( + llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { + if (root_is_floating_point) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, + {lhs, rhs}, {lhs->getType()}, + ir_builder); + } + + return ir_builder->CreateSelect( + ir_builder->CreateICmp(root_is_signed ? 
llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs, rhs), + lhs, rhs); + }; + } +} + +IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( + PrimitiveType element_type, unsigned element_count) { + // Here we assume that the largest register is a vector register. + int max_vector_register_size_in_bytes = + target_machine_features_.largest_register_size_in_bytes( + compute_function_); + + int vector_register_size_in_elements = + max_vector_register_size_in_bytes / + ShapeUtil::ByteSizeOfPrimitiveType(element_type); + + ShardedVectorType sharded_vector_type; + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_); + + for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) { + // For every power of two present in element_count, we generate one or more + // vector or scalar types. + const unsigned current_size_fragment = 1u << i; + if (!(element_count & current_size_fragment)) { + // Power of two not present in element_count. + continue; + } + + if (current_size_fragment == 1) { + // Single element, use a scalar type. + sharded_vector_type.push_back(element_ir_type); + continue; + } + + // Lower "current_size_fragment" number of elements using (as few as + // possible) vector registers. + + if (current_size_fragment >= vector_register_size_in_elements) { + auto vector_type = llvm::VectorType::get( + element_ir_type, vector_register_size_in_elements); + sharded_vector_type.insert( + sharded_vector_type.end(), + current_size_fragment / vector_register_size_in_elements, + vector_type); + + // Both current_size_fragment and vector_register_size_in_elements are + // powers of two. + CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0); + continue; + } + + // For now we assume that vector_register_size_in_elements and lower powers + // of two are all legal vector sizes (or at least can be lowered easily by + // LLVM). 
+ sharded_vector_type.push_back( + llvm::VectorType::get(element_ir_type, current_size_fragment)); + } + return sharded_vector_type; +} + +StatusOr +IrEmitter::EmitInnerLoopForVectorizedReduction( + const ReductionGenerator& reduction_generator, + const llvm_ir::IrArray::Index& output_index, + const ShardedVectorType& accumulator_type, HloInstruction* init_value, + HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + unsigned element_alignment) { + ShardedVector accumulator; + accumulator.reserve(accumulator_type.size()); + for (auto accumulator_shard_type : accumulator_type) { + accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry( + accumulator_shard_type, "accumulator", &ir_builder_, 0)); + } + + llvm::Value* init_value_ssa = + ir_builder_.CreateLoad(GetEmittedValueFor(init_value)); + + for (llvm::Value* accumulator_shard : accumulator) { + llvm::Value* initial_value; + auto shard_type = accumulator_shard->getType()->getPointerElementType(); + if (auto vector_type = llvm::dyn_cast(shard_type)) { + initial_value = ir_builder_.CreateVectorSplat( + vector_type->getNumElements(), init_value_ssa); + } else { + initial_value = init_value_ssa; + } + + ir_builder_.CreateAlignedStore(initial_value, accumulator_shard, + element_alignment); + } + + llvm_ir::ForLoopNest reduction_loop_nest(&ir_builder_); + llvm_ir::IrArray::Index reduced_dims_index = + reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray::Index input_index = reduced_dims_index; + llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); + + for (size_t i = 0; i < input_index.size(); ++i) { + if (input_index[i] == nullptr) { + input_index[i] = *it++; + } + } + CHECK(output_index.end() == it); + + llvm::Value* input_address = ir_builder_.CreateBitCast( + 
arg_array.EmitArrayElementAddress(input_index, &ir_builder_), + ir_builder_.getInt8PtrTy()); + + for (int i = 0; i < accumulator.size(); i++) { + auto input_address_typed = + ir_builder_.CreateBitCast(input_address, accumulator[i]->getType()); + auto current_accumulator_value = + ir_builder_.CreateAlignedLoad(accumulator[i], element_alignment); + auto addend = + ir_builder_.CreateAlignedLoad(input_address_typed, element_alignment); + arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); + + auto reduced_result = + reduction_generator(&ir_builder_, current_accumulator_value, addend); + ir_builder_.CreateAlignedStore(reduced_result, accumulator[i], + element_alignment); + + if (i != (accumulator.size() - 1)) { + input_address = ir_builder_.CreateConstInBoundsGEP1_32( + reduced_result->getType(), input_address_typed, 1); + } + } + + SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), + &ir_builder_); + + ShardedVector result_ssa; + result_ssa.reserve(accumulator.size()); + for (auto accumulator_shard : accumulator) { + result_ssa.push_back( + ir_builder_.CreateAlignedLoad(accumulator_shard, element_alignment)); + } + return result_ssa; +} + +void IrEmitter::EmitShardedVectorStore( + llvm::Value* store_address, const std::vector& value_to_store, + const int alignment, const llvm_ir::IrArray& containing_array) { + for (int i = 0; i < value_to_store.size(); i++) { + auto store_address_typed = ir_builder_.CreateBitCast( + store_address, + llvm::PointerType::getUnqual(value_to_store[i]->getType())); + + auto store_instruction = ir_builder_.CreateAlignedStore( + value_to_store[i], store_address_typed, alignment); + containing_array.AnnotateLoadStoreInstructionWithMetadata( + store_instruction); + + if (i != (value_to_store.size() - 1)) { + store_address = ir_builder_.CreateConstInBoundsGEP1_32( + value_to_store[i]->getType(), store_address_typed, 1); + } + } +} + +namespace { +// TODO(sanjoy): This is duplicated in 
tensorflow/core/lib/core/arena.cc. +// Extract out a common implementation to tensorflow/core/lib/math/math_util.h +uint32 GCD(uint32 x, uint32 y) { + while (y != 0) { + uint32 r = x % y; + x = y; + y = r; + } + return x; +} +} // namespace + +StatusOr IrEmitter::EmitVectorizedReduce( + HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, HloComputation* function, + string* failure_reason) { + ReductionGenerator reduction_generator = + MatchReductionGenerator(function, failure_reason); + if (!reduction_generator) { + return false; + } + + int vectorization_factor_in_bytes = + target_machine_features_.vectorization_factor_in_bytes(); + + // We try to process vectorization_factor elements at the same time. + const int vectorization_factor = + vectorization_factor_in_bytes / + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); + + bool is_reduction_over_minor_dimension = + std::find(dimensions.begin(), dimensions.end(), + arg->shape().layout().minor_to_major(0)) != dimensions.end(); + + unsigned element_alignment = + GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + + if (is_reduction_over_minor_dimension) { + // TODO(sanjoy): Implement vectorized reduction over the minor dimension. + *failure_reason = "reduction over minor dimension not implemented"; + return false; + } + + CHECK(!ShapeUtil::IsTuple(reduce->shape())); + + // We know we're not reducing over the most minor dimension, which means we + // can lower the reduction loop as: + // + // 1. We're reducing over dimensions R0, R1. + // 2. D0 is the most minor dimension. + // 3. 
VS is the vectorization stride (we want to reduce this many elements at + // once) + // + // for (d1 in D1) { + // for (d0 in D0 with stride VS) { + // vector_acc = init + // for (r1 in R1) { + // for (r0 in R0) { + // vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0] + // } + // } + // output[d1, d0] = vector_acc + // } + // } + + llvm_ir::ForLoopNest loop_nest(&ir_builder_); + llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); + for (int i = reduce->shape().layout().minor_to_major_size() - 1; i > 0; --i) { + int64 dimension = reduce->shape().layout().minor_to_major(i); + int64 start_index = 0; + int64 end_index = reduce->shape().dimensions(dimension); + std::unique_ptr loop = + loop_nest.AddLoop(start_index, end_index, + tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + + int64 innermost_dimension = reduce->shape().layout().minor_to_major(0); + int64 innermost_dimension_size = + reduce->shape().dimensions(innermost_dimension); + + if (llvm::BasicBlock* innermost_body_bb = + loop_nest.GetInnerLoopBodyBasicBlock()) { + SetToFirstInsertPoint(innermost_body_bb, &ir_builder_); + } + + auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock(); + + if (innermost_dimension_size >= vectorization_factor) { + int64 start_index = 0; + int64 end_index = (innermost_dimension_size / vectorization_factor) * + vectorization_factor; + std::unique_ptr loop = loop_nest.AddLoop( + start_index, end_index, vectorization_factor, + tensorflow::strings::Printf("dim.%lld", innermost_dimension)); + array_index[innermost_dimension] = loop->GetIndVarValue(); + + SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &ir_builder_); + + ShardedVectorType vector_type = CreateShardedVectorType( + reduce->shape().element_type(), vectorization_factor); + TF_ASSIGN_OR_RETURN(std::vector accumulator, + EmitInnerLoopForVectorizedReduction( + reduction_generator, array_index, vector_type, + 
init_value, arg, dimensions, element_alignment)); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(reduce)); + llvm_ir::IrArray target_array(target_address, reduce->shape()); + AddAliasingInformationToIrArray(*reduce, &target_array); + llvm::Value* output_address = + target_array.EmitArrayElementAddress(array_index, &ir_builder_); + EmitShardedVectorStore(output_address, accumulator, element_alignment, + target_array); + + if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { + CHECK_GT(reduce->shape().layout().minor_to_major_size(), 1); + ir_builder_.SetInsertPoint(exit_terminator); + } else { + CHECK_EQ(reduce->shape().layout().minor_to_major_size(), 1); + ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); + } + } + + // Since we increment the stride for the inner dimension by more than 1, we + // may need to peel out an "epilogue" iteration to get the remaining elements + // in the following case: + if (innermost_dimension_size % vectorization_factor) { + // TODO(b/63775531): Consider using a scalar loop here to save on code size. 
+ array_index[innermost_dimension] = + ir_builder_.getInt64(innermost_dimension_size - + (innermost_dimension_size % vectorization_factor)); + + ShardedVectorType vector_type = CreateShardedVectorType( + reduce->shape().element_type(), + innermost_dimension_size % vectorization_factor); + TF_ASSIGN_OR_RETURN(std::vector accumulator, + EmitInnerLoopForVectorizedReduction( + reduction_generator, array_index, vector_type, + init_value, arg, dimensions, element_alignment)); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(reduce)); + llvm_ir::IrArray target_array(target_address, reduce->shape()); + AddAliasingInformationToIrArray(*reduce, &target_array); + llvm::Value* output_address = + target_array.EmitArrayElementAddress(array_index, &ir_builder_); + EmitShardedVectorStore(output_address, accumulator, element_alignment, + target_array); + } + + if (outermost_loop_exit_block) { + ir_builder_.SetInsertPoint(outermost_loop_exit_block); + } + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(reduce)); + + emitted_value_[reduce] = target_address; + return true; +} + Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { + if (!VectorizedReduceDisabled(hlo_module_config_)) { + string vectorization_failure_reason; + TF_ASSIGN_OR_RETURN( + bool vectorization_successful, + EmitVectorizedReduce(reduce, arg, init_value, dimensions, function, + &vectorization_failure_reason)); + if (vectorization_successful) { + VLOG(1) << "Successfully vectorized reduction " << reduce->ToString() + << "\n"; + return Status::OK(); + } else { + VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": " + << vectorization_failure_reason; + } + } + // The called computation should have been emitted previously. 
llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); return EmitTargetElementLoop( @@ -1140,13 +1938,143 @@ Status IrEmitter::HandleSend(HloInstruction* send) { } Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { - if (ShapeUtil::IsScalar(slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(slice)); - emitted_value_[slice] = target_address; - return EmitMemcpy(*operand, *slice); + VLOG(2) << "HandleSlice: " << slice->ToString(); + + // The code below emits a sequential loop nest. For the parallel backend, use + // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + if (ShouldEmitParallelLoopFor(*slice)) { + return DefaultAction(slice); + } + + // The code below assumes the layouts are equal. + if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) { + return DefaultAction(slice); + } + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(slice)); + emitted_value_[slice] = target_address; + + if (ShapeUtil::HasZeroElements(slice->shape())) { + return Status::OK(); + } + + const Layout& layout = operand->shape().layout(); + const int64 num_dims = operand->shape().dimensions_size(); + + // The slice lowering finds maximal contiguous blocks of memory that can be + // copied from the source to the target. This is done by looking at the + // source/target layout in minor to major order and do the following: + // + // * Find an initial segment of dimensions along which the slice uses the + // whole dimension. These are the "inner" dimensions and can be folded into + // the memcpy. + // + // * Of the remaining dimensions decide which ones require loops. + // + // * Implement the memcpy within the innermost loop. 
+ + tensorflow::gtl::FlatSet inner_dims; + for (int64 dim : layout.minor_to_major()) { + if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { + break; + } + inner_dims.insert(dim); + } + + const bool is_trivial_copy = (inner_dims.size() == num_dims); + if (is_trivial_copy) { + if (ShapeUtil::IsEffectiveScalar(slice->shape())) { + return DefaultAction(slice); + } else { + return EmitMemcpy(*slice, *operand); + } } - return DefaultAction(slice); + + // The memcpy will copy elements that are logically this shape (allowed to be + // scalar). + const Shape element_shape = ShapeUtil::FilterDimensions( + [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); }, + operand->shape()); + + // memcpy_dim is the innermost (in terms of layout) dimension for which the + // slice does *not* just copy all the elements along the dimension. + const int64 memcpy_dim = layout.minor_to_major(inner_dims.size()); + + const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1; + // The number of logical elements that can be copied in a single call + // to memcpy. We can only copy 1 element at a time if there is a non-trivial + // stride. + const int64 memcpy_elements = + memcpy_is_contiguous + ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim) + : 1; + + if (memcpy_elements == 1 && ShapeUtil::IsEffectiveScalar(element_shape)) { + // Avoid using memcpy for copying element by element at a time. This does + // not buy us anything and may actually cause LLVM's load/store optimization + // to be less effective. + return DefaultAction(slice); + } + + // Determine the dimensions that get lowered as loops. + std::vector outer_dims; + for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) { + outer_dims.push_back(LayoutUtil::Major(layout, i)); + } + + // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim + // needs to be wrapped around a loop as well. 
+ if (!memcpy_is_contiguous) { + outer_dims.push_back(memcpy_dim); + } + + llvm_ir::IrArray target_array(target_address, slice->shape()); + AddAliasingInformationToIrArray(*slice, &target_array); + + const int64 num_outer_loops = outer_dims.size(); + llvm_ir::ForLoopNest loops(&ir_builder_); + llvm_ir::IrArray::Index target_index = + loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); + + // Only the indices for the outer dimensions have been initialized in + // target_index. The rest of the indices should get initialized to 0, since + // for the rest of the dimensions the copy writes to the full dimension. + for (llvm::Value*& index : target_index) { + if (index == nullptr) { + index = ir_builder_.getInt64(0); + } + } + + if (num_outer_loops > 0) { + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + } + + llvm_ir::IrArray source_array(GetEmittedValueFor(operand), operand->shape()); + + const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( + /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), + /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); + + llvm::Value* memcpy_dest = target_array.EmitArrayElementAddress( + target_index, &ir_builder_, "slice.dest"); + llvm::Value* memcpy_source = source_array.EmitArrayElementAddress( + source_index, &ir_builder_, "slice.source"); + const int64 memcpy_bytes = + ShapeUtil::ByteSizeOf(element_shape) * memcpy_elements; + // TODO(b/63762267): Be more aggressive with `align` by using the GCD of the + // element size and buffer alignment. 
+ ir_builder_.CreateMemCpy(memcpy_dest, memcpy_source, memcpy_bytes, + /*align=*/1); + + VLOG(2) << " emitted memcpy of " << memcpy_bytes << " bytes inside " + << num_outer_loops << " loops"; + + if (num_outer_loops > 0) { + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + } + + return Status::OK(); } Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, @@ -1283,7 +2211,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, dot->operand(0)->IsRank2Transpose(), dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, - GetExecutableRunOptionsArgument(), &ir_builder_)); + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_)); emitted_value_[fusion] = target_address; return Status::OK(); @@ -1568,6 +2496,7 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( } Status IrEmitter::Preprocess(HloInstruction* hlo) { + VLOG(3) << "Visiting: " << hlo->ToString(); if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } @@ -1606,13 +2535,24 @@ llvm::Argument* IrEmitter::GetResultArgument() { } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - return hlo_to_profile_idx_ ? GetArg(compute_function_, 4) : nullptr; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; + return hlo_to_profile_idx_ ? 
GetArg(compute_function_, arg_index) : nullptr; } llvm::Value* IrEmitter::GetTempBuffersArgument() { return GetArg(compute_function_, 3); } +llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_.CreateLoad( + ir_builder_.CreateGEP(loop_bounds_arg, ir_builder_.getInt64(offset), + llvm_ir::AsStringRef(name))); +} + llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return GetArg(compute_function_, 1); } @@ -1645,11 +2585,14 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( GetTempBuffersArgument(), slice.index(), &ir_builder_); llvm::LoadInst* tempbuf_address_base = ir_builder_.CreateLoad(tempbuf_address_ptr); - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. - tempbuf_address_base->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); + if (hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // Loading the address of a buffer is invariant of the point at which the + // load is executed in the program because we never reassign buffers. 
+ tempbuf_address_base->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); + } llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape, /*is_pointer_to=*/true); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); @@ -1739,13 +2682,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( } StatusOr IrEmitter::EmitTargetAddressForOp( - const HloInstruction* op) { - const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { + const HloInstruction* op, const ShapeIndex& shape_index) { + const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index); + if (op == op->parent()->root_instruction() && shape_index.empty()) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); - if (!ShapeUtil::HasZeroElements(target_shape)) { + if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); @@ -1773,16 +2716,103 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(llvm::Value * target_address, EmitTargetAddressForOp(target_op)); VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*target_op, &target_array); - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) - .EmitLoop()); + if (target_op->IsMultiOutputFusion()) { + // For multiple outputs fusion, we need to emit each operand and the root. 
+ TF_RET_CHECK(num_dynamic_loop_bounds_ == 0); + std::vector output_arrays; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(target_op, {i})); + const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); + llvm::Value* op_target_address = + EmitTempBufferPointer(slice, element_shape); + output_arrays.push_back( + llvm_ir::IrArray(op_target_address, element_shape)); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape), + tuple_operand_ptrs, &ir_builder_); + + } else { + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*target_op, &target_array); + + if (ShouldEmitParallelLoopFor(*target_op)) { + TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( + target_shape, element_generator, &target_array)); + } else { + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) + .EmitLoop()); + } + } + emitted_value_[target_op] = target_address; return Status::OK(); } +Status IrEmitter::EmitParallelTargetElementLoop( + const Shape& target_shape, + const llvm_ir::ElementGenerator& element_generator, + llvm_ir::IrArray* target_array) { + CHECK(!ShapeUtil::IsTuple(target_shape)); + CHECK(!ShapeUtil::IsScalar(target_shape)); + + // Emit code to read dynamic loop bounds from function argument 4. 
+ std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); + for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i] = GetDynamicLoopBound(i); + } + + llvm_ir::ForLoopNest loop_nest(&ir_builder_); + const int64 num_dims = target_shape.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { + const int64 dimension = target_shape.layout().minor_to_major(i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < num_dynamic_loop_bounds_) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; + llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. + std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/target_shape.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Emit loop body. + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + element_generator(array_index)); + target_array->EmitWriteArrayElement(array_index, target_element, + &ir_builder_); + // Point IR builder at outer loop exit BB. 
+ SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); + + return Status::OK(); +} + Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); @@ -1825,5 +2855,36 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } +unsigned TargetMachineFeatures::largest_register_size_in_bytes( + llvm::Function* function) { + auto itr = largest_register_size_in_bytes_.find(function); + if (itr != largest_register_size_in_bytes_.end()) { + return itr->second; + } + + int result = largest_register_size_in_bytes_impl(function); + + InsertOrDie(&largest_register_size_in_bytes_, function, result); + DCHECK_EQ(result, largest_register_size_in_bytes_.begin()->second); + return result; +} + +unsigned TargetMachineFeatures::largest_register_size_in_bytes_impl( + llvm::Function* function) const { + auto register_info = + target_machine_->getSubtargetImpl(*function)->getRegisterInfo(); + + unsigned largest_register_size = 0; + for (const llvm::TargetRegisterClass* register_class : + register_info->regclasses()) { + if (register_class->isAllocatable()) { + largest_register_size = + std::max(largest_register_size, + register_info->getRegSizeInBits(*register_class)); + } + } + + return largest_register_size / 8; +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index ebb7296a075f266870fa179a0791dd6d0f77e29f..1a77f695809d471f5c7d03d01ec291093326a9ef 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -41,12 +41,55 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace cpu { +// Wraps an llvm::TargetMachine and parses out some information that feeds into +// LLVM IR code generation decisions. +// +// Ideally we'd be able to use llvm::TargetTransformInfo here (since its +// interface is pretty much a perfect fit for our use case), but obtaining an +// instance of llvm::TargetTransformInfo outside an LLVM pass pipeline without +// super-ugly hacks is difficult. +// +// TODO(b/27457097): See if the LLVM community will be receptive to exposing an +// API that lets us directly create and use llvm::TargetTransformInfo instances +// outside of a pass manager. +class TargetMachineFeatures { + public: + TargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + // Return the vectorization factor, which is the number of bytes of data + // explicitly vectorized routines will try to process at once. + int vectorization_factor_in_bytes() const { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + // Return the size of the largest register in bytes. We need to pass in + // "function" since llvm functions can contain annotations for specializing + // them to specific micro-architectures (though currently XLA does not use + // this functionality). + // + // Ideally we should have been able to use + // llvm::TargetTransformInfo::getRegisterBitWidth(true) here.
+ unsigned largest_register_size_in_bytes(llvm::Function* function); + + private: + unsigned largest_register_size_in_bytes_impl(llvm::Function* function) const; + + tensorflow::gtl::FlatMap + largest_register_size_in_bytes_; + llvm::TargetMachine* target_machine_; +}; + // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. @@ -63,7 +106,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* - hlo_to_profile_idx); + hlo_to_profile_idx, + llvm::TargetMachine* target_machine); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -96,7 +140,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, @@ -106,9 +150,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* rhs) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; - Status HandleOutfeed(HloInstruction* infeed) override; + Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; 
Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, @@ -192,6 +238,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of the computation + // function being emitted by this emitter. + llvm::Value* GetDynamicLoopBound(const int64 offset); + // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -262,6 +313,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); + // Emit IR to perform a computation for every element in a partition/slice of + // 'target_shape'. The loop bounds for the outer-dimension partitions are + // passed into the compute function as a runtime argument (accessible from + // GetDynamicLoopBound). + Status EmitParallelTargetElementLoop( + const Shape& target_shape, + const llvm_ir::ElementGenerator& element_generator, + llvm_ir::IrArray* target_array); + // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -271,7 +331,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Emit IR to compute the target address of the buffer for the given op. // The returned Value is a pointer to a IR type that represents the op's // element type. - StatusOr EmitTargetAddressForOp(const HloInstruction* op); + StatusOr EmitTargetAddressForOp( + const HloInstruction* op, const ShapeIndex& shape_index = {}); // Structurizes "array_elements" into an MD array that represents "shape". 
// This is a recursive function, and "dimension_index" indicates the index of @@ -281,6 +342,71 @@ class IrEmitter : public DfsHloVisitorWithDefault { const std::vector& array_elements, const Shape& shape, int64 dimension_index); + // Tries to codegen a reduction operation using vectorized instructions. + // Returns true if successful, and false on failure. On failure, sets + // "failure_reason" to a string describing why it could not vectorize the + // reduction. + // + // TODO(sanjoy): Some of the things we do here can be abstracted out into + // concepts that generalize over other vectorizable operations. We should + // consider pulling out these abstractions into a VectorizingIrEmitter or + // something similar. + StatusOr EmitVectorizedReduce( + HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, HloComputation* function, + string* failure_reason); + + // We'd like to keep one or two cache-lines' worth of data in registers + // without generating IR with illegal (e.g. excessively large or + // non-power-of-two) vector types. We do this by introducing a layer of + // abstraction: we introduce a high level vector-like concept called a + // "sharded vector" that models data parallelism, and is mapped to a sequence of + // scalar and vector llvm::Value s. + // + // For example, we can represent 29 f32 elements by a sharded vector mapped to + // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32]. + // Note that the last element is scalar. + // + // There is no requirement on the ordering or the uniqueness of the elements + // mapped to sharded vectors -- we allow repeated elements, and we allow + // elements to appear in any order. + using ShardedVector = std::vector; + + // A sharded vector type is the element-wise llvm::Type's of some + // ShardedVector.
+ using ShardedVectorType = std::vector; + + // Create a sharded vector type corresponding to a "element_count" long + // sequence of "element_type" values. + ShardedVectorType CreateShardedVectorType(PrimitiveType element_type, + unsigned element_count); + + // Emit LLVM IR to store the sharded vector "value_to_store" to + // "store_address". + void EmitShardedVectorStore(llvm::Value* store_address, + const ShardedVector& value_to_store, + const int alignment, + const llvm_ir::IrArray& containing_array); + + using ReductionGenerator = std ::function*, llvm::Value*, llvm::Value*)>; + + // Tries to match the reduction function "function" to a known reduction + // pattern. Returns a non-null ReductionGenerator on a successful match, + // which can be used to generate the LLVM IR corresponding to said reduction. + // On failure, this stores a reason string into "failure_reason". + ReductionGenerator MatchReductionGenerator(HloComputation* function, + string* failure_reason) const; + + // Emits the inner loop nest that runs the reduction. Helper function for + // EmitVectorizedReduce. + StatusOr EmitInnerLoopForVectorizedReduction( + const ReductionGenerator& reduction_generator, + const llvm_ir::IrArray::Index& output_index, + const ShardedVectorType& accumulator_type, HloInstruction* init_value, + HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + unsigned element_alignment); + // Name of the computation entry function. This function serves as the // top-level "main" of the computation and will be invoked by the JIT. string entry_function_name_; @@ -319,6 +445,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; + // The number of root instruction outer dimensions used in parallel loop + // emission (EmitParallelTargetElementLoop). + int64 num_dynamic_loop_bounds_ = 0; + + // Returns whether the given instruction should be emitted as a parallel loop. 
+ bool ShouldEmitParallelLoopFor(const HloInstruction& op) const { + // Emit parallel loop for root instruction if dynamic outer-dimension loop + // bounds were specified. + return num_dynamic_loop_bounds_ > 0 && + op.parent()->root_instruction() == &op; + } + // This struct contains all the state needed to emit instructions for // profiling a computation. class ProfilingState { @@ -404,8 +542,20 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; + enum class XfeedKind { + kInfeed, + kOutfeed, + }; + + // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program + // address. + Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, + llvm::Value* program_buffer_address); + const HloModuleConfig& hlo_module_config_; + TargetMachineFeatures target_machine_features_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index bdddca99c2f50c47ab112eda92ab1509f5448849..f0af3e7b894af875c222b184873dcc4cc9e79b8f 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -70,7 +71,7 @@ ParallelCpuExecutable::ParallelCpuExecutable( // Type of the computation function we expect in the JIT. 
using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - uint64*); + int64*, uint64*); // Given a pointer to an output buffer (following the CPU JIT calling // conventions), mark addresses that are "live". The initial pointer itself is @@ -95,6 +96,232 @@ static void MarkLiveAddressesInOutput( } } +namespace { + +// Executor manages the concurrent execution of 'functions' for instructions +// in 'pending' on 'thread_pool' (storing resulting data in 'results'). +class Executor { + public: + Executor(const std::map& functions, + const ServiceExecutableRunOptions* run_options, + std::list* pending, + std::map* results, void** temps_array, + uint64* profile_counters_array, BufferAssignment* assignment) + : functions_(functions), + run_options_(run_options), + pending_(pending), + results_(results), + temps_array_(temps_array), + profile_counters_array_(profile_counters_array), + thread_pool_(CHECK_NOTNULL(run_options_->xla_intra_op_thread_pool())), + assignment_(assignment) {} + + // Executes pending list of instructions on thread pool. + // Returns OK status on success, error status otherwise. + Status Run(); + + private: + // Schedules a parallel invocation of compute function for 'instruction' on + // 'thread_pool_', storing result in 'result_buffer'. + // If 'partition_buffers' is non-null, parallel task will be invoked on + // per-dimension partition [start, limit) values stored in + // 'partition_buffers'. + void Schedule(HloInstruction* instruction, int64* partition_buffers, + void* result_buffer); + + // Returns true if 'instruction' has been assigned parallel tasks (returns + // false otherwise). + bool HasParallelTasks(HloInstruction* instruction); + + // Returns in 'partition_buffers' the partition [size, limit) for each + // dimension. + int64* GetPartitionBuffers( + const std::vector>& partition); + + // Returns array of result buffers for all operands in 'instruction'. 
+ const void** GetOperandBuffers(HloInstruction* instruction); + + // Arguments passed into Executor. + const std::map& functions_; + const ServiceExecutableRunOptions* run_options_; + std::list* pending_; + std::map* results_; + void** temps_array_; + uint64* profile_counters_array_; + tensorflow::thread::ThreadPool* thread_pool_; + BufferAssignment* assignment_; + + // Members used to manage instruction execution. + tensorflow::mutex completion_queue_lock_; + tensorflow::condition_variable completion_queue_cv_; + std::deque completion_queue_; + int64 instructions_in_flight_ = 0; + std::unordered_map tasks_in_flight_; +}; + +Status Executor::Run() { + while (!pending_->empty() || instructions_in_flight_ > 0) { + auto pending_it = pending_->begin(); + while (pending_it != pending_->end()) { + HloInstruction* instruction = *pending_it; + // Skip pending instructions whose operands aren't ready. + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), + [&](HloInstruction* operand) { + return !ContainsKey(*results_, operand); + })) { + ++pending_it; + continue; + } + + // Get 'result_buffer' reference to result buffer for 'instruction'. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array_[result_slice.index()]) + + result_slice.offset(); + + if (HasParallelTasks(instruction)) { + // 'instruction' has been assigned parallel task partitions. + CHECK_EQ(HloOpcode::kCall, instruction->opcode()); + HloInstruction* root = instruction->to_apply()->root_instruction(); + + // Create ShapePartitionIterator to iterate through all outer dimension + // partitions of 'instruction'. + ShapePartitionIterator partition_iterator( + root->shape(), root->outer_dimension_partitions()); + + const int64 partition_count = + partition_iterator.GetTotalPartitionCount(); + + // Record total parallel task count for 'instruction' before dispatch. 
+ { + tensorflow::mutex_lock l(completion_queue_lock_); + tasks_in_flight_.insert(std::make_pair(instruction, partition_count)); + VLOG(2) << "Schedule PARALLEL" + << " instruction: " << instruction->name() + << " instruction.callee: " + << instruction->to_apply()->root_instruction()->name() + << " partition_count: " << partition_count; + } + + for (int64 i = 0; i < partition_count; ++i) { + // Get partition [start, limit) for each dimension. + auto partition_buffers = + GetPartitionBuffers(partition_iterator.GetPartition(i)); + Schedule(instruction, partition_buffers, result_buffer); + } + + } else { + // Set tasks in-flight to '1' for sequential instruction execution. + { + tensorflow::mutex_lock l(completion_queue_lock_); + tasks_in_flight_.insert(std::make_pair(instruction, 1)); + VLOG(2) << "Schedule SEQUENTIAL" + << " instruction: " << instruction->name() + << " instruction.callee: " + << instruction->to_apply()->root_instruction()->name(); + } + Schedule(instruction, nullptr, result_buffer); + } + + ++instructions_in_flight_; + pending_it = pending_->erase(pending_it); + } + // Wait for a completed HLO instruction to be present in the queue. We will + // pop it out of the queue and make the result available to its users. 
+ HloInstruction* instruction; + do { + tensorflow::mutex_lock l(completion_queue_lock_); + if (completion_queue_.empty()) { + completion_queue_cv_.wait(l); + } + if (!completion_queue_.empty()) { + instruction = completion_queue_.front(); + completion_queue_.pop_front(); + break; + } + } while (1); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array_[result_slice.index()]) + + result_slice.offset(); + InsertOrDie(results_, instruction, result_buffer); + --instructions_in_flight_; + } + return Status::OK(); +} + +void Executor::Schedule(HloInstruction* instruction, int64* partition_buffers, + void* result_buffer) { + // The thread pool entry takes ownership of |operand_buffers|. + auto operand_buffers = GetOperandBuffers(instruction); + + auto function = FindOrDie(functions_, instruction); + const auto* exec_run_options = &run_options_->run_options(); + thread_pool_->Schedule([this, instruction, result_buffer, operand_buffers, + partition_buffers, exec_run_options, function]() { + function(result_buffer, exec_run_options, operand_buffers, temps_array_, + partition_buffers, profile_counters_array_); + + delete[] operand_buffers; + delete[] partition_buffers; + // Push the completed HLO instruction on the queue, the main + // thread will pop it off and potentially launch more work which + // uses the result. + // TODO(b/27458679) Consider alternative task scheduling and synchronization + // schemes. For example, we could avoid the overhead associate with the + // condvar here if the thread just dequed the next instruction to execute + // on completion. + { + tensorflow::mutex_lock l(completion_queue_lock_); + // Decrement in-flight task count for this completion. 
+ if (--FindOrDie(tasks_in_flight_, instruction) == 0) { + completion_queue_.push_back(instruction); + completion_queue_cv_.notify_all(); + tasks_in_flight_.erase(instruction); + } + } + }); +} + +int64* Executor::GetPartitionBuffers( + const std::vector>& partition) { + // Return in 'partition_buffers' partition [start, limit) for each dimension. + auto partition_buffers = new int64[partition.size() * 2]; + for (int i = 0; i < partition.size(); ++i) { + partition_buffers[2 * i + 0] = partition[i].first; + partition_buffers[2 * i + 1] = partition[i].first + partition[i].second; + } + return partition_buffers; +} + +bool Executor::HasParallelTasks(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCall && + !instruction->to_apply() + ->root_instruction() + ->outer_dimension_partitions() + .empty(); +} + +const void** Executor::GetOperandBuffers(HloInstruction* instruction) { + // We cannot use a move-only RAII type like std::unique_ptr because the + // list of operands is allocated on the main thread and transferred to the + // worker via the lambda passed to enqueue_function. In order for the + // lambda to take ownership, we would need to use generalized lambda + // capture which is a feature new to C++14. + // TODO(b/27458679) Avoid dynamic allocations in Executor. 
+ auto operand_buffers = new const void*[instruction->operand_count()]; + std::transform(instruction->operands().begin(), instruction->operands().end(), + operand_buffers, [this](HloInstruction* operand) { + return FindOrDie(*results_, operand); + }); + return operand_buffers; +} + +} // namespace + Status ParallelCpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -180,8 +407,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( HloInstruction* instruction = entry.first; llvm::JITSymbol sym = jit_->FindSymbol(entry.second); TF_RET_CHECK(sym); - InsertOrDie(&functions, instruction, - reinterpret_cast(sym.getAddress())); + InsertOrDie( + &functions, instruction, + reinterpret_cast(cantFail(sym.getAddress()))); } // Map containing pointers to result buffers for each instruction. @@ -210,88 +438,16 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } } - void** temps_array = buffer_pointers.data(); - uint64* profile_counters_array = profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); - tensorflow::mutex completion_queue_lock; - tensorflow::condition_variable completion_queue_cv; - std::deque completion_queue; - int64 instructions_in_flight = 0; - while (!pending.empty() || instructions_in_flight > 0) { - auto pending_it = pending.begin(); - while (pending_it != pending.end()) { - HloInstruction* instruction = *pending_it; - // Skip pending instructions whose operands aren't ready. - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), - [&](HloInstruction* operand) { - return !ContainsKey(results, operand); - })) { - ++pending_it; - continue; - } + // TODO(b/27458679) Manage scheduling based on in-flight concurrency limits. 
+ // For example, if we expect a library conv/matmul call to run at max + // concurrency, we should not dispatch runnable instructions until the + // library call is finished (to avoid expensive cache invalidation). + Executor executor(functions, run_options, &pending, &results, + buffer_pointers.data(), profile_counters.data(), + assignment_.get()); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array[result_slice.index()]) + - result_slice.offset(); - // We cannot use a move-only RAII type like std::unique_ptr because the - // list of operands is allocated on the main thread and transferred to the - // worker via the lambda passed to enqueue_function. In order for the - // lambda to take ownership, we would need to use generalized lambda - // capture which is a feature new to C++14. - auto operand_buffers = new const void*[instruction->operand_count()]; - std::transform(instruction->operands().begin(), - instruction->operands().end(), operand_buffers, - [&results](HloInstruction* operand) { - return FindOrDie(results, operand); - }); - auto function = FindOrDie(functions, instruction); - // The thread pool entry takes ownership of |operand_buffers|. - const auto* exec_run_options = &run_options->run_options(); - thread_pool->Schedule([instruction, &completion_queue, - &completion_queue_lock, &completion_queue_cv, - result_buffer, exec_run_options, operand_buffers, - temps_array, profile_counters_array, function] { - function(result_buffer, exec_run_options, operand_buffers, temps_array, - profile_counters_array); - delete[] operand_buffers; - // Push the completed HLO instruction on the queue, the main thread - // will pop it off and potentially launch more work which uses the - // result. 
- { - tensorflow::mutex_lock l(completion_queue_lock); - completion_queue.push_back(instruction); - completion_queue_cv.notify_all(); - } - }); + TF_RETURN_IF_ERROR(executor.Run()); - ++instructions_in_flight; - pending_it = pending.erase(pending_it); - } - // Wait for a completed HLO instruction to be present in the queue. We will - // pop it out of the queue and make the result available to its users. - HloInstruction* instruction; - do { - tensorflow::mutex_lock l(completion_queue_lock); - if (completion_queue.empty()) { - completion_queue_cv.wait(l); - } - if (!completion_queue.empty()) { - instruction = completion_queue.front(); - completion_queue.pop_front(); - break; - } - } while (1); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array[result_slice.index()]) + - result_slice.offset(); - InsertOrDie(&results, instruction, result_buffer); - --instructions_in_flight; - } uint64 end_micros = tensorflow::Env::Default()->NowMicros(); { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 6d5f790c3941af5cac098fd39c1dace5564cee5b..a3fe2657989ef2a7bd001e49d1baab57b3def839 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -89,6 +89,12 @@ class ParallelCpuExecutable : public Executable { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + const Status EqualOrFail(const Executable& executable) { + // TODO(b/62952745) Implement equality test on CPU parallel executable. + return Unimplemented( + "Equality test on CPU parallel executable is not implemented."); + } + private: // Allocate buffers required for execution and assign them to the elements of // "buffers". 
"buffers" should be sized to the number of buffers in buffer diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 8f1ce82d49a1c7cabfb62bf30e69faedc0318138..b3f4609d465efb4df8921abb684bafd263fe040f 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -38,13 +38,12 @@ int main(int argc, char** argv) { // Transfer parameters. std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + xla::Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr param1_literal = xla::Literal::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = client->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -69,7 +68,7 @@ int main(int argc, char** argv) { LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", profile.compute_time_ns()); - LOG(INFO) << xla::LiteralUtil::ToString(*actual); + LOG(INFO) << actual->ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc new file mode 100644 index 0000000000000000000000000000000000000000..61b408b8c24dded134218110d4e219c31f1685a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" + +namespace xla { +namespace cpu { + +std::vector ShapePartitionAssigner::Run(int64 target_partition_count) { + // Gather outer-most dims where dim_size >= 'target_partition_count'. + // Note: always leave inner-dim static for vectorization/optimizations. + std::vector outer_dims; + int64 outer_dim_size = 1; + // TODO(b/27458679) Consider reserving enough minor dimensions (based on + // target vector register width) to enable vector instructions. + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + const int64 dimension = shape_.layout().minor_to_major(i); + outer_dims.push_back(dimension); + outer_dim_size *= shape_.dimensions(dimension); + if (outer_dim_size >= target_partition_count) { + break; + } + } + + // Clip target partition count if outer dim size is insufficient to cover. + target_partition_count = std::min(outer_dim_size, target_partition_count); + + // Calculate the target number of partitions per-dimension, by factoring + // 'target_partition_count' into 'num_outer_dims' equal terms. + // EX: + // *) target_partition_count = 16 + // *) out_dim_count = 2 + // *) target_dim_partition_count = 16 ^ (1.0 / 2) == 4 + const int64 target_dim_partition_count = std::pow( + static_cast(target_partition_count), 1.0 / outer_dims.size()); + + // Assign feasible dimension partitions based on 'target_dim_partition_count' + // and actual dimension sizes from 'shape_'. 
+ std::vector dimension_partition_counts(outer_dims.size()); + for (int64 i = 0; i < outer_dims.size(); ++i) { + dimension_partition_counts[i] = + std::min(static_cast(shape_.dimensions(outer_dims[i])), + target_dim_partition_count); + } + + // Check if total partition count is below 'target_partition_count'. + // This can occur if some dimensions in 'shape_' are below the + // 'target_dim_partition_count' threshold. + if (GetTotalPartitionCount(dimension_partition_counts) < + target_partition_count) { + // Assign additional partitions (greedily to outer dimensions), if doing + // so would keep the total number of partitions <= 'target_partition_count', + // using one pass over 'dimension_partition_counts'. + for (int64 i = 0; i < dimension_partition_counts.size(); ++i) { + const int64 current_dim_partition_count = dimension_partition_counts[i]; + const int64 other_dims_partition_count = + GetTotalPartitionCount(dimension_partition_counts) / + current_dim_partition_count; + // Constraint: (current + additional) * other <= target + // Calculate: additional = target / other - current + int64 additional_partition_count = + target_partition_count / other_dims_partition_count - + current_dim_partition_count; + // Clip 'additional_partition_count' by current dimension size. 
+ additional_partition_count = std::min( + shape_.dimensions(outer_dims[i]) - dimension_partition_counts[i], + additional_partition_count); + if (additional_partition_count > 0) { + dimension_partition_counts[i] += additional_partition_count; + } + } + } + + return dimension_partition_counts; +} + +int64 ShapePartitionAssigner::GetTotalPartitionCount( + const std::vector& dimension_partition_counts) { + int64 total_partition_count = 1; + for (int64 dim_partition_count : dimension_partition_counts) { + total_partition_count *= dim_partition_count; + } + return total_partition_count; +} + +ShapePartitionIterator::ShapePartitionIterator( + const Shape& shape, const std::vector& dimension_partition_counts) + : shape_(shape), + dimension_partition_counts_(dimension_partition_counts), + dimensions_(dimension_partition_counts_.size()), + dimension_partition_sizes_(dimension_partition_counts_.size()), + dimension_partition_strides_(dimension_partition_counts_.size()) { + // Store partitioned outer dimensions from 'shape_'. + for (int i = 0; i < dimensions_.size(); ++i) { + dimensions_[i] = shape_.layout().minor_to_major( + shape_.layout().minor_to_major_size() - 1 - i); + } + + // Calculate partition size for each dimension (note that the size of + // the last partition in each dimension may be different if the dimension + // size is not a multiple of partition size). + for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { + const int64 dim_size = shape_.dimensions(dimensions_[i]); + dimension_partition_sizes_[i] = + std::max(1LL, dim_size / dimension_partition_counts_[i]); + } + + // Calculate the partition strides for each dimension. 
+ dimension_partition_strides_[dimension_partition_strides_.size() - 1] = 1; + for (int i = dimension_partition_strides_.size() - 2; i >= 0; --i) { + dimension_partition_strides_[i] = dimension_partition_strides_[i + 1] * + dimension_partition_counts_[i + 1]; + } +} + +std::vector> ShapePartitionIterator::GetPartition( + int64 index) const { + // Calculate and return the partition for 'index'. + // Returns for each dimension: (partition_start, partition_size). + std::vector> partition(dimensions_.size()); + for (int64 i = 0; i < partition.size(); ++i) { + // Calculate the index for dimension 'i'. + const int64 partition_index = index / dimension_partition_strides_[i]; + // Calculate dimension partition start at 'partition_index'. + partition[i].first = partition_index * dimension_partition_sizes_[i]; + // Calculate dimension partition size (note that the last partition size + // may be adjusted if dimension size is not a multiple of partition size). + if (partition_index == dimension_partition_counts_[i] - 1) { + // Last partition in this dimension. + partition[i].second = + shape_.dimensions(dimensions_[i]) - partition[i].first; + } else { + partition[i].second = dimension_partition_sizes_[i]; + } + CHECK_GT(partition[i].second, 0); + // Update index to remove contribution from current dimension. + index -= partition_index * dimension_partition_strides_[i]; + } + return partition; +} + +int64 ShapePartitionIterator::GetTotalPartitionCount() const { + return ShapePartitionAssigner::GetTotalPartitionCount( + dimension_partition_counts_); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h new file mode 100644 index 0000000000000000000000000000000000000000..7a2d00421cfdc8e41ec48698a16665621de16bda --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ + +#include + +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { +namespace cpu { + +// ShapePartitionAssigner partitions the most-major dimensions of 'shape' such +// that the total partition count <= 'target_partition_count'. +// +// Example 1: +// +// Let 'shape' = [8, 16, 32] and 'target_partition_count' = 6. +// +// Because the most-major dimension size is >= 'target_partition_count', we +// can generate our target number of partitions by partitioning the most-major +// dimension. +// +// This will result in the following partitions of the most-major dimension: +// +// [0, 1), [1, 2), [2, 3), [3, 4), [4, 5), [5, 8) +// +// Note that the last partition has a residual because the dimension size is +// not a multiple of the partition count. +// +// +// Example 2: +// +// Let 'shape' = [8, 16, 32] and 'target_partition_count' = 16. +// +// Because the most-major dimension only has size 8, we must also partition +// the next most-major dimension to generate the target of 16 partitions. 
+// We factor 'target_partition_count' by the number of most-major dimensions +// we need to partition, to get a per-dimension target partition count: +// +// target_dimension_partition_count = 16 ^ (1 / 2) == 4 +// +// This will result in the following partitions of the most-major dimension: +// +// [0, 2), [2, 4), [4, 6), [6, 8) +// +// This will result in the following partitions of the second most-major +// dimension: +// +// [0, 4), [4, 8), [8, 12), [12, 16) +// +class ShapePartitionAssigner { + public: + ShapePartitionAssigner(const Shape& shape) : shape_(shape) {} + + // Returns dimension partition counts (starting at outer-most dimension). + std::vector Run(int64 target_partition_count); + + // Returns the total partition count based on 'dimension_partition_counts'. + static int64 GetTotalPartitionCount( + const std::vector& dimension_partition_counts); + + private: + const Shape& shape_; +}; + +// ShapePartitionIterator iterates through outer-dimension partitions of +// 'shape' as specified by 'dimension_partition_counts'. +class ShapePartitionIterator { + public: + ShapePartitionIterator(const Shape& shape, + const std::vector& dimension_partition_counts); + + // Returns a partition [start, size] for each dimension. + // Partitions are listed starting from outer-most dimension first. 
+ std::vector> GetPartition(int64 index) const; + + int64 GetTotalPartitionCount() const; + + private: + const Shape& shape_; + const std::vector dimension_partition_counts_; + + std::vector dimensions_; + std::vector dimension_partition_sizes_; + std::vector dimension_partition_strides_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee0c53fa6d7c41481a53350e57e5844dea2644c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" + +#include +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace cpu { +namespace { + +class ShapePartitionAssignerTest : public HloTestBase { + protected: + typedef std::vector Vec; + + void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + ShapePartitionAssigner assigner(shape); + // Check all partitions of outer dimension. + for (int64 i = 1; i <= expected_max_partition_count; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), + assigner.Run(/*target_partition_count=*/i))); + } + // Check target_partition_count > outer dimension size. + EXPECT_TRUE(ContainersEqual( + Vec({expected_max_partition_count}), + assigner.Run( + /*target_partition_count=*/expected_max_partition_count + 1))); + } +}; + +TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); +} + +TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); +} + +TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); +} + +TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); +} + +TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); + ShapePartitionAssigner assigner(shape); + + for (int64 i = 1; i <= 5; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( + /*target_partition_count=*/i))); + } + + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), 
assigner.Run(/*target_partition_count=*/7))); + EXPECT_TRUE( + ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/10))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/11))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/12))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/13))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/14))); + EXPECT_TRUE(ContainersEqual(Vec({5, 3}), + assigner.Run(/*target_partition_count=*/15))); + EXPECT_TRUE(ContainersEqual(Vec({5, 3}), + assigner.Run(/*target_partition_count=*/16))); +} + +TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); + ShapePartitionAssigner assigner(shape); + + for (int64 i = 1; i <= 3; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( + /*target_partition_count=*/i))); + } + + EXPECT_TRUE( + ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); + EXPECT_TRUE( + ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/10))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/11))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + 
assigner.Run(/*target_partition_count=*/12))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/13))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/14))); + EXPECT_TRUE(ContainersEqual(Vec({3, 5}), + assigner.Run(/*target_partition_count=*/15))); + EXPECT_TRUE(ContainersEqual(Vec({3, 5}), + assigner.Run(/*target_partition_count=*/16))); +} + +class ShapePartitionIteratorTest : public HloTestBase { + protected: + typedef std::vector> Partition; +}; + +TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}); + + { + ShapePartitionIterator iterator(shape, {1}); + EXPECT_EQ(1, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); + } + + { + ShapePartitionIterator iterator(shape, {2}); + EXPECT_EQ(2, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); + EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); + } + + { + ShapePartitionIterator iterator(shape, {3}); + EXPECT_EQ(3, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); + EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); + } +} + +TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); + + { + ShapePartitionIterator iterator(shape, {1, 1}); + EXPECT_EQ(1, iterator.GetTotalPartitionCount()); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); + } + + { + ShapePartitionIterator iterator(shape, {2, 2}); + EXPECT_EQ(4, iterator.GetTotalPartitionCount()); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); + 
EXPECT_TRUE( + ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); + EXPECT_TRUE( + ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); + EXPECT_TRUE( + ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); + } +} + +class RandomShapePartitionIteratorTest : public HloTestBase { + protected: + typedef std::vector> Partition; + RandomShapePartitionIteratorTest() + : generator_(rd_()), distribution_(1, 10) {} + + std::vector RandR4Dims() { return {Rand(), Rand(), Rand(), Rand()}; } + + int64 Rand() { return distribution_(generator_); } + + std::random_device rd_; + std::mt19937 generator_; + std::uniform_int_distribution distribution_; +}; + +TEST_F(RandomShapePartitionIteratorTest, RandomShapeAndPartitions) { + // Choose random dimensions for R4 shape. + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, RandR4Dims(), {3, 2, 1, 0}); + // Choose random number of outer dimensions to partition. + const int num_outer_dims_to_partition = 1 + (Rand() % 3); + // Choose random outer dimension partition counts. + std::vector dim_sizes(num_outer_dims_to_partition); + std::vector dim_partition_counts(num_outer_dims_to_partition); + int64 total_dim_size = 1; + for (int i = 0; i < num_outer_dims_to_partition; ++i) { + const int64 dimension = shape.layout().minor_to_major( + shape.layout().minor_to_major_size() - 1 - i); + dim_sizes[i] = shape.dimensions(dimension); + total_dim_size *= dim_sizes[i]; + // Choose dimension partition count in [1, dim_size] + const int64 dim_partition_count = 1 + Rand() % dim_sizes[i]; + dim_partition_counts[i] = dim_partition_count; + } + // Iterate through all partition: for each partition record covered + // index ranges by dimension. 
+ std::vector> ranges(num_outer_dims_to_partition); + ShapePartitionIterator partition_iterator(shape, dim_partition_counts); + const int64 partition_count = partition_iterator.GetTotalPartitionCount(); + for (int64 i = 0; i < partition_count; ++i) { + const auto& dim_partition = partition_iterator.GetPartition(i); + for (int dim = 0; dim < dim_partition.size(); ++dim) { + ranges[dim].insert( + std::make_pair(dim_partition[dim].first, + dim_partition[dim].first + dim_partition[dim].second)); + } + } + // Check that partitions cover entire dimension size range (for each + // partitioned dimension). + for (int i = 0; i < ranges.size(); ++i) { + int64 expected_index = 0; + for (auto& r : ranges[i]) { + EXPECT_EQ(expected_index, r.first); + expected_index = r.second; + } + EXPECT_EQ(expected_index, dim_sizes[i]); + } +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 7c74912a7ab9c388c9911fe8194f268623f0abd1..262c471b4079b92daee98095b2fa61834cb2f243 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -21,11 +21,12 @@ limitations under the License. 
#include #include +#include "external/llvm/include/llvm/ExecutionEngine/ExecutionEngine.h" +#include "external/llvm/include/llvm/ExecutionEngine/SectionMemoryManager.h" #include "external/llvm/include/llvm/IR/Mangler.h" #include "external/llvm/include/llvm/Support/CodeGen.h" #include "external/llvm/include/llvm/Support/Host.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" @@ -41,7 +42,7 @@ namespace cpu { namespace { // Converts a symbol 'name' into the form expected by dlsym(). -std::string CanonicalizeSymbol(const std::string &name) { +std::string CanonicalizeSymbol(const std::string& name) { #if defined(__APPLE__) // On Mac OS X, dlsym() expects names not to be prefixed with a leading // underscore. @@ -52,47 +53,77 @@ std::string CanonicalizeSymbol(const std::string &name) { return name; } +class JITSymbolTable { + public: + JITSymbolTable() { Populate(); } + + void* Lookup(llvm::StringRef jit_symbol_name) const { + auto it = jit_symbol_table_.find(jit_symbol_name); + return it == jit_symbol_table_.end() ? nullptr : it->getValue(); + } + + static bool MustBeInTable(llvm::StringRef name) { + // In particular, names starting with + // runtime::kXlaCpuRuntimeSymbolNamePrefix should not be dlsym'ed. + return name.startswith(runtime::kXlaCpuRuntimeSymbolNamePrefix); + } + + private: + void AddJITSymbolToTable(llvm::StringRef jit_symbol_name, + llvm::StringRef cpp_symbol_name, + void* jit_symbol_value) { + // The JIT symbol name and the C++ symbol name (with an extern "C" linkage) + // need to match, otherwise AOT links will fail. 
+ CHECK(jit_symbol_name == cpp_symbol_name); + CHECK(jit_symbol_table_.insert({jit_symbol_name, jit_symbol_value}).second); + } + + void Populate() { +#define ADD_JIT_SYMBOL_TO_TABLE(base_name) \ + do { \ + AddJITSymbolToTable( \ + xla::cpu::runtime::k##base_name##SymbolName, \ + "__xla_cpu_runtime_" #base_name, \ + reinterpret_cast(__xla_cpu_runtime_##base_name)); \ + } while (false) + + ADD_JIT_SYMBOL_TO_TABLE(AcquireInfeedBufferForDequeue); + ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue); + ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation); + ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation); + ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32); + ADD_JIT_SYMBOL_TO_TABLE(LogV8F32); + ADD_JIT_SYMBOL_TO_TABLE(TanhV8F32); + ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32); + ADD_JIT_SYMBOL_TO_TABLE(LogV4F32); + ADD_JIT_SYMBOL_TO_TABLE(TanhV4F32); + ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32); + ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32); + ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64); + ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32); + ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32); + ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64); + +#undef ADD_JIT_SYMBOL_TO_TABLE + } + + llvm::StringMap jit_symbol_table_; +}; + +const JITSymbolTable& GetJITSymbolTable() { + static JITSymbolTable* symbol_table = new JITSymbolTable; + return *symbol_table; +} + // A simple SymbolResolver that delegates to the host dynamic linker. 
struct SimpleResolver : public llvm::JITSymbolResolver { - llvm::JITSymbol findSymbol(const std::string &name) override { - void *func_addr = nullptr; - + llvm::JITSymbol findSymbol(const std::string& name) override { std::string canonical_name = CanonicalizeSymbol(name); - if (canonical_name == runtime::kEigenMatmulF32SymbolName) { - func_addr = reinterpret_cast(__xla_cpu_runtime_EigenMatMulF32); - } else if (canonical_name == - runtime::kEigenSingleThreadedMatmulF32SymbolName) { - func_addr = reinterpret_cast( - __xla_cpu_runtime_EigenSingleThreadedMatMulF32); - } else if (canonical_name == runtime::kEigenConvF32SymbolName) { - func_addr = reinterpret_cast(__xla_cpu_runtime_EigenConvF32); - } else if (canonical_name == - runtime::kEigenSingleThreadedConvF32SymbolName) { - func_addr = reinterpret_cast( - __xla_cpu_runtime_EigenSingleThreadedConvF32); - } else if (canonical_name == - runtime::kAcquireInfeedBufferForDequeueSymbolName) { - func_addr = reinterpret_cast( - __xla_cpu_runtime_AcquireInfeedBufferForDequeue); - } else if (canonical_name == - runtime::kReleaseInfeedBufferAfterDequeueSymbolName) { - func_addr = reinterpret_cast( - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue); - } else if (canonical_name == runtime::kExpV4F32) { - func_addr = reinterpret_cast(runtime::ExpV4F32); - } else if (canonical_name == runtime::kExpV8F32) { - func_addr = reinterpret_cast(runtime::ExpV8F32); - } else if (canonical_name == runtime::kLogV4F32) { - func_addr = reinterpret_cast(runtime::LogV4F32); - } else if (canonical_name == runtime::kLogV8F32) { - func_addr = reinterpret_cast(runtime::LogV8F32); - } else if (canonical_name == runtime::kTanhV4F32) { - func_addr = reinterpret_cast(runtime::TanhV4F32); - } else if (canonical_name == runtime::kTanhV8F32) { - func_addr = reinterpret_cast(runtime::TanhV8F32); - } else { - func_addr = dlsym(RTLD_DEFAULT, canonical_name.c_str()); - } + const JITSymbolTable& jit_symbol_table = GetJITSymbolTable(); + + void* func_addr = 
JITSymbolTable::MustBeInTable(canonical_name) + ? jit_symbol_table.Lookup(canonical_name) + : dlsym(RTLD_DEFAULT, canonical_name.c_str()); if (func_addr == nullptr) { return nullptr; @@ -101,7 +132,7 @@ struct SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbolFlags::None); return symbol_info; } - llvm::JITSymbol findSymbolInLogicalDylib(const std::string &name) override { + llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { return nullptr; } }; @@ -110,7 +141,7 @@ llvm::SmallVector DetectMachineAttributes() { llvm::SmallVector result; llvm::StringMap host_features; if (llvm::sys::getHostCPUFeatures(host_features)) { - for (auto &feature : host_features) { + for (auto& feature : host_features) { if (feature.second) { llvm::StringRef feature_name = feature.first(); // Skip avx512 for now, it isn't quite ready in LLVM. @@ -133,15 +164,17 @@ llvm::StringRef GetHostCpuName() { CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { CompilerFunctor::VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = (&runtime::ExpV4F32 != nullptr); - intrinsics.avx_intrinsics = (&runtime::ExpV8F32 != nullptr); + intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32 != nullptr); + intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32 != nullptr); return intrinsics; } } // namespace -SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, - llvm::CodeGenOpt::Level opt_level) +SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, + CompilerFunctor::ModuleHook pre_optimization_hook, + CompilerFunctor::ModuleHook post_optimization_hook) : target_machine_( CHECK_NOTNULL(llvm::EngineBuilder() .setTargetOptions(target_options) @@ -152,33 +185,33 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), + object_layer_( + [] { return 
std::make_shared(); }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, GetAvailableIntrinsics())) { + opt_level, GetAvailableIntrinsics(), + std::move(pre_optimization_hook), + std::move(post_optimization_hook))) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { - // The Orc API adds a whole iterable "set" of modules, so we wrap the module - // in a vector. - std::vector> module_set; - module_set.push_back(std::move(module)); - auto handle = compile_layer_.addModuleSet( - std::move(module_set), MakeUnique(), - MakeUnique()); + auto handle = cantFail(compile_layer_.addModule( + std::move(module), MakeUnique())); module_handles_.push_back(handle); return handle; } void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) { module_handles_.erase( - std::remove(module_handles_.begin(), module_handles_.end(), handle)); - compile_layer_.removeModuleSet(handle); + std::remove(module_handles_.begin(), module_handles_.end(), handle), + module_handles_.end()); + cantFail(compile_layer_.removeModule(handle)); } -llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string &name) { +llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { std::string mangled_name; { llvm::raw_string_ostream mangled_name_stream(mangled_name); @@ -187,7 +220,7 @@ llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string &name) { // Resolve symbol from last module to first, allowing later redefinitions of // symbols shadow earlier ones. 
- for (auto &handle : + for (auto& handle : llvm::make_range(module_handles_.rbegin(), module_handles_.rend())) { if (auto symbol = compile_layer_.findSymbolIn(handle, mangled_name, diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 4d8653484a037a345321dbe11c384f650e0142d0..f57049c9dde23c7ac540d22dde0b7be9be38e2e0 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -25,6 +25,7 @@ limitations under the License. #include "external/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/types.h" @@ -41,9 +42,12 @@ namespace cpu { // it's added to the JIT. class SimpleOrcJIT { public: - using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer<>; - using CompileLayerT = llvm::orc::IRCompileLayer; - using ModuleHandleT = CompileLayerT::ModuleSetHandleT; + using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; + using CompileFtor = + std::function( + llvm::Module&)>; + using CompileLayerT = llvm::orc::IRCompileLayer; + using ModuleHandleT = CompileLayerT::ModuleHandleT; // Create a new JIT, targeting the host architecture. // The |target_options| parameter allows customization of certain code @@ -51,8 +55,14 @@ class SimpleOrcJIT { // can be reassociated, etc.). // The |opt_level| parameter controls the optimization level of the code // generator. + // The |pre_optimization_hook| is invoked on the module before any IR + // level optimizations are applied. + // The |post_optimization_hook| is invoked on the module after all IR + // level optimizations are applied. 
SimpleOrcJIT(const llvm::TargetOptions& target_options, - llvm::CodeGenOpt::Level opt_level); + llvm::CodeGenOpt::Level opt_level, + CompilerFunctor::ModuleHook pre_optimization_hook, + CompilerFunctor::ModuleHook post_optimization_hook); // Data layout this JIT was created with. const llvm::DataLayout& data_layout() const { return data_layout_; } @@ -73,6 +83,8 @@ class SimpleOrcJIT { // nullptr if the symbol cannot be found. llvm::JITSymbol FindSymbol(const std::string& name); + llvm::TargetMachine* target_machine() const { return target_machine_.get(); } + private: std::vector module_handles_; std::unique_ptr target_machine_; diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc similarity index 56% rename from tensorflow/compiler/xla/service/cpu/infeed_manager.cc rename to tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index 14c882a06ee9fdfc66f3d6db55146431634dd85e..2160c3cd01df0b359f986d5e5ba16c71e2676584 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -13,32 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace cpu { namespace runtime { -InfeedBuffer::~InfeedBuffer() = default; - -InfeedManager::InfeedManager() : current_buffer_(nullptr) {} +void XfeedManager::Reset() { + infeed()->Reset(); + outfeed()->Reset(); +} -void InfeedManager::Reset() { +void XfeedQueueManager::Reset() { tensorflow::mutex_lock l(mu_); - CHECK(!current_buffer_); - for (auto buffer : enqueued_buffer_) { - buffer->Done(); + CHECK(current_buffer_ == nullptr); + for (auto buffer : enqueued_buffers_) { + buffer->Done(ShapeUtil::MakeNil()); } - enqueued_buffer_.clear(); + enqueued_buffers_.clear(); } -void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { +void XfeedQueueManager::EnqueueBuffers( + tensorflow::gtl::ArraySlice buffers) { tensorflow::mutex_lock l(mu_); - bool was_empty = enqueued_buffer_.empty(); - enqueued_buffer_.push_back(buffer); - if (was_empty) { + bool was_empty = enqueued_buffers_.empty(); + for (XfeedBuffer* b : buffers) { + enqueued_buffers_.push_back(b); + } + if (was_empty && !buffers.empty()) { // This has the potential to suffer from the notified thread // immediately trying and failing to acquire mu_, but seems // preferable to the alternative of notifying outside the lock @@ -47,23 +52,24 @@ void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { } } -InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { +XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { + while (enqueued_buffers_.empty()) { cv_.wait(l); } - CHECK(!current_buffer_); - current_buffer_ = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); + CHECK(current_buffer_ == nullptr); + current_buffer_ = 
enqueued_buffers_.front(); + enqueued_buffers_.pop_front(); return current_buffer_; } -void InfeedManager::ReleaseCurrentBuffer(int32 length, void* data) { +void XfeedQueueManager::ReleaseCurrentBuffer(int32 length, void* data, + StatusOr shape) { tensorflow::mutex_lock l(mu_); - CHECK(current_buffer_); + CHECK(current_buffer_ != nullptr); CHECK_EQ(length, current_buffer_->length()); CHECK_EQ(data, current_buffer_->data()); - current_buffer_->Done(); + current_buffer_->Done(std::move(shape)); current_buffer_ = nullptr; } diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h similarity index 50% rename from tensorflow/compiler/xla/service/cpu/infeed_manager.h rename to tensorflow/compiler/xla/service/cpu/xfeed_manager.h index 77472746e659b2ddbd9b54a036775ebdd0084fdd..86af789384e0a926b2e469daac68b6e1521875bc 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -17,12 +17,15 @@ limitations under the License. // is used by the CPU runtime to transfer buffers into an executing // CPU computation, e.g., to feed data into a while loop. -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_ #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -31,62 +34,89 @@ namespace runtime { // Abstract class defining an infeed buffer that is passed to the // runtime by a client. The client manages the storage of the buffer. 
-class InfeedBuffer { +class XfeedBuffer { public: - virtual ~InfeedBuffer(); + virtual ~XfeedBuffer() = default; virtual int32 length() = 0; virtual void* data() = 0; - virtual void Done() = 0; + + // The 'shape' parameter reflects what shape the embedded program was + // expecting / producing with respect to this XfeedBuffer. E.g. this will + // contain information about the layout of an outfed buffer. + virtual void Done(StatusOr shape) = 0; }; -// Client-side class used to enqueue infeed buffers. -class InfeedManager { +// Reusable component for managing the infeed and outfeed queue state. +class XfeedQueueManager { public: - InfeedManager(); + XfeedQueueManager() = default; // Calls the completion callback for any enqueued buffers that have - // not been dequeued by the runtime, and empties the infeed + // not been dequeued by the runtime, and empties the // queue. Reset may not be called while a runtime computation is // processing a dequeued buffer. The only safe way to ensure this // condition is to call Reset when no computation is taking place. void Reset(); - // Adds buffer to the infeed queue. buffer->Done will be called when - // the buffer will no longer be accessed by the InfeedManager, - // either as a result of a call to Reset or because the runtime has - // dequeued and used the buffer. - void EnqueueBuffer(InfeedBuffer* buffer); - - // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Sets the current buffer to be - // the returned buffer. It is an error to call BlockingDequeueBuffer - // if there is an unreleased current buffer, i.e., - // ReleaseCurrentBuffer must be called between calls to + // Adds a sequence of buffers to the queue atomically. buffer->Done will be + // called when the buffer will no longer be accessed by the XfeedManager, + // either as a result of a call to Reset or because the runtime has dequeued + // and used the buffer. 
+ void EnqueueBuffers(tensorflow::gtl::ArraySlice buffers); + + // Blocks until the queue is non-empty, then returns the buffer at the head of + // the queue. Sets the current buffer to be the returned buffer. It is an + // error to call BlockingDequeueBuffer if there is an unreleased current + // buffer, i.e., ReleaseCurrentBuffer must be called between calls to // BlockingDequeueBuffer. - InfeedBuffer* BlockingDequeueBuffer(); + XfeedBuffer* BlockingDequeueBuffer(); // Releases the current buffer, which is the last buffer returned by // BlockingDequeuBuffer and not yet released. length and data must // match the buffer->length() and buffer->data() for the current // buffer. - void ReleaseCurrentBuffer(int32 length, void* data); + // + // 'shape' communicates the shape of the buffer being released. If the program + // passed a value that could not be decoded as a shape, 'shape' will be an + // error status. In the case of outfeed, this indicates the layout of the + // shape that has been outfed. In the case of infeed, this can be used for + // sanity checking purposes. + void ReleaseCurrentBuffer(int32 length, void* data, StatusOr shape); private: tensorflow::mutex mu_; + // Condition variable that is signaled every time a buffer is // enqueued to an empty queue. tensorflow::condition_variable cv_; - // InfeedBuffer* queue contents are not owned, but buffer->Done must + + // XfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. - std::deque enqueued_buffer_; + std::deque enqueued_buffers_; + // If non-NULL, the buffer that is currently being processed by the // runtime. Not owned. - InfeedBuffer* current_buffer_; + XfeedBuffer* current_buffer_ = nullptr; +}; + +// Client-side class used to enqueue infeed buffers. 
+class XfeedManager { + public: + XfeedManager() = default; + + void Reset(); + + XfeedQueueManager* infeed() { return &infeed_; } + XfeedQueueManager* outfeed() { return &outfeed_; } + + private: + XfeedQueueManager infeed_; + XfeedQueueManager outfeed_; }; } // namespace runtime } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8defd28b013512a0b6ace0c23dff4a38fe505385 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" + +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class InfeedManagerTest : public ::testing::Test {}; + +class TestInfeedBuffer : public cpu::runtime::XfeedBuffer { + public: + explicit TestInfeedBuffer(int32 length, bool expect_shape_match = true) + : shape_(ShapeUtil::MakeShape(U8, {length})), + done_called_(false), + length_(length), + expect_shape_match_(expect_shape_match) {} + ~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); } + + int32 length() override { return length_; } + void* data() override { return nullptr; } + void Done(StatusOr shape) override { + CHECK(!done_called_); + done_called_ = true; + TF_ASSERT_OK(shape.status()); + EXPECT_EQ(expect_shape_match_, ShapeUtil::Equal(shape_, shape.ValueOrDie())) + << "want " << ShapeUtil::HumanString(shape_) << " " + << (expect_shape_match_ ? "==" : "!=") << " " + << ShapeUtil::HumanString(shape.ValueOrDie()); + } + + const Shape& shape() const { return shape_; } + + private: + Shape shape_; + bool done_called_; + int32 length_; + bool expect_shape_match_; +}; + +// Performs the acquire/release sequence on the infeed, as the generated CPU +// code would in the process of executing the infeed operation. 
+void ProcessNextBuffer(int32 length) { + auto shape = ShapeUtil::MakeShape(U8, {length}); + string bytes = shape.SerializeAsString(); + void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( + length, bytes.data(), bytes.size()); + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer, + bytes.data(), bytes.size()); +} + +// Performs the acquire/release sequence on the outfeed, as the generated CPU +// code would in the process of executing the outfeed operation. +void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) { + string bytes = shape.SerializeAsString(); + void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + length, bytes.data(), bytes.size()); + __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + length, buffer, bytes.data(), bytes.size()); +} + +TEST_F(InfeedManagerTest, SingleThreadedSequential) { + TestInfeedBuffer* a = new TestInfeedBuffer(64); + TestInfeedBuffer* b = new TestInfeedBuffer(32); + + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + + xfeed->infeed()->EnqueueBuffers({a}); + xfeed->infeed()->EnqueueBuffers({b}); + ProcessNextBuffer(a->length()); + ProcessNextBuffer(b->length()); +} + +TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { + TestInfeedBuffer* a = new TestInfeedBuffer(64); + TestInfeedBuffer* b = new TestInfeedBuffer(32); + + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + + xfeed->infeed()->EnqueueBuffers({a}); + ProcessNextBuffer(a->length()); + xfeed->infeed()->EnqueueBuffers({b}); + ProcessNextBuffer(b->length()); +} + +TEST_F(InfeedManagerTest, MultiThreaded) { + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); + + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + + const int32 length = 64; + + pool.Schedule([xfeed]() { + // Spin for 100 milliseconds + int64 start_micros = tensorflow::Env::Default()->NowMicros(); + while (true) { + int64 end_micros = 
tensorflow::Env::Default()->NowMicros(); + if ((end_micros - start_micros) >= 100000) { // 100 ms + break; + } + } + TestInfeedBuffer* a = new TestInfeedBuffer(length); + xfeed->infeed()->EnqueueBuffers({a}); + }); + + ProcessNextBuffer(length); +} + +TEST_F(InfeedManagerTest, OutfeedWrongShape) { + TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + xfeed->outfeed()->EnqueueBuffers({b}); + + ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 2d9d9c7de62a34e4d18ef1d7f5552a85ad1c49cb..d8a76443a66c9234aa0b93d1d21e213fd3ba67ab 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -21,15 +21,17 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace se = ::perftools::gputools; @@ -38,7 +40,7 @@ namespace xla { namespace { -class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { +class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer { public: explicit CpuInfeedBuffer(int32 length) 
: length_(length), @@ -48,7 +50,7 @@ class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { int32 length() override { return length_; } void* data() override { return buffer_; } - void Done() override { delete this; } + void Done(StatusOr /*shape*/) override { delete this; } se::DeviceMemoryBase* device_memory() { return &device_memory_; } @@ -58,6 +60,30 @@ class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer { se::DeviceMemoryBase device_memory_; }; +class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { + public: + CpuOutfeedBuffer(void* destination, int32 length) + : destination_(destination), length_(length) {} + + StatusOr WaitForNotification() { + done_.WaitForNotification(); + return status_; + } + + int32 length() override { return length_; } + void* data() override { return destination_; } + void Done(StatusOr shape) override { + status_ = std::move(shape); + done_.Notify(); + } + + private: + void* destination_; + int32 length_; + StatusOr status_; + tensorflow::Notification done_; +}; + } // namespace CpuTransferManager::CpuTransferManager() @@ -66,34 +92,173 @@ CpuTransferManager::CpuTransferManager() Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { const Shape& shape = literal.shape(); - VLOG(2) << "transferring literal shape to infeed: " + VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - // TODO(b/31381668) handle tuples. 
- if (ShapeUtil::IsTuple(shape)) { - return Unimplemented("Infeed with a tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + if (!ShapeUtil::IsTuple(shape)) { + int64 size = GetByteSizeRequirement(shape); + return TransferBufferToInfeed(executor, size, literal.InternalData()); } - cpu::runtime::InfeedManager* infeed_manager = - cpu::runtime::GetInfeedManager(); + if (ShapeUtil::IsNestedTuple(shape)) { + return Unimplemented( + "Infeed with a nested tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + // For a tuple, we transfer each of its elements to the device and + // enqueue the resulting destination device addresses with the + // infeed manager. + std::vector buffers; + buffers.reserve(literal.tuple_literals_size()); + auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { + for (cpu::runtime::XfeedBuffer* b : buffers) { + b->Done(ShapeUtil::MakeNil()); + } + }); + + for (const auto& tuple_element : literal.tuple_literals()) { + const Shape& tuple_element_shape = tuple_element.shape(); + int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); + TF_ASSIGN_OR_RETURN( + cpu::runtime::XfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + tuple_element.InternalData())); + buffers.push_back(buffer); + } + + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + xfeed_manager->infeed()->EnqueueBuffers(buffers); + + cleanup.release(); + return Status::OK(); +} + +Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, + const void* source) { + TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, size, source)); + + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + xfeed_manager->infeed()->EnqueueBuffers({buffer}); - int64 size = GetByteSizeRequirement(shape); + return Status::OK(); +} + +StatusOr 
+CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, + int64 size, + const void* source) { if (size > std::numeric_limits::max()) { - return Unimplemented("Infeed shape is too large: %s needs %lld bytes", - ShapeUtil::HumanString(literal.shape()).c_str(), size); + return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); } + + if (size <= 0) { + return InvalidArgument("Infeed shape must have positive size; got %lld", + size); + } + int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - TF_RETURN_IF_ERROR(TransferBufferToDevice( - executor, /*size=*/size, /*source=*/LiteralUtil::InternalData(literal), - queued_buffer->device_memory())); + Status s = + TransferBufferToDevice(executor, /*size=*/size, + /*source=*/source, queued_buffer->device_memory()); - infeed_manager->EnqueueBuffer(queued_buffer); + if (!s.ok()) { + queued_buffer->Done(ShapeUtil::MakeNil()); + return s; + } + return queued_buffer; +} +Status CpuTransferManager::TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) { + if (!ShapeUtil::IsTuple(literal_shape)) { + int64 size = GetByteSizeRequirement(literal_shape); + // Note: OSS build didn't like implicit conversion from + // literal_shape.dimensions() to the array slice on 2017-07-10. 
+ tensorflow::gtl::ArraySlice dimensions( + tensorflow::bit_cast(literal_shape.dimensions().data()), + literal_shape.dimensions().size()); + auto empty = + Literal::CreateFromDimensions(literal_shape.element_type(), dimensions); + literal->Swap(empty.get()); + TF_ASSIGN_OR_RETURN(Shape received_shape, + TransferBufferFromOutfeed( + executor, size, literal->MutableInternalData())); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + << "Shape received from outfeed " + << ShapeUtil::HumanString(received_shape) + << " did not match the shape that was requested for outfeed: " + << ShapeUtil::HumanString(literal_shape); + TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); + *literal->mutable_shape() = received_shape; + return Status::OK(); + } + + if (ShapeUtil::IsNestedTuple(literal_shape)) { + return Unimplemented( + "Nested tuple outfeeds are not yet implemented on CPU."); + } + + std::vector> elements; + for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(literal_shape, i); + // Note: OSS build didn't like implicit conversion from + // literal_shape.dimensions() to the array slice on 2017-07-10. 
+ tensorflow::gtl::ArraySlice dimensions( + tensorflow::bit_cast( + tuple_element_shape.dimensions().data()), + tuple_element_shape.dimensions().size()); + auto empty = Literal::CreateFromDimensions( + tuple_element_shape.element_type(), dimensions); + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferBufferFromOutfeed(executor, + GetByteSizeRequirement(tuple_element_shape), + empty->MutableInternalData())); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, tuple_element_shape)) + << "Shape received from outfeed " + << ShapeUtil::HumanString(received_shape) + << " did not match the shape that was requested for outfeed: " + << ShapeUtil::HumanString(tuple_element_shape); + TF_RET_CHECK(GetByteSizeRequirement(tuple_element_shape) == + GetByteSizeRequirement(received_shape)); + *empty->mutable_shape() = received_shape; + elements.push_back(std::move(empty)); + } + auto result = Literal::MakeTupleOwned(std::move(elements)); + literal->Swap(result.get()); + TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); return Status::OK(); } +StatusOr CpuTransferManager::TransferBufferFromOutfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + void* destination) { + if (size > std::numeric_limits::max()) { + return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + size); + } + + if (size <= 0) { + return InvalidArgument("Outfeed shape must have positive size; got %lld", + size); + } + + int32 size_32 = static_cast(size); + cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + CpuOutfeedBuffer buffer(destination, size_32); + VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " + << size_32 << "B"; + xfeed_manager->outfeed()->EnqueueBuffers({&buffer}); + VLOG(2) << "Waiting for buffer to be notified as populated."; + return buffer.WaitForNotification(); +} + } // namespace xla static std::unique_ptr CreateCpuTransferManager() { diff --git 
a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h index 727462252d7291959fd09c05c87e36411eb3ddab..30dc2d90623fb20656874d40a25b6a8449a7486c 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" @@ -37,8 +38,24 @@ class CpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; + Status TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) override; private: + // Transfers infeed data to device. InfeedBuffer->Done() must be + // called to clean up the memory allocated for InfeedBuffer. + StatusOr TransferBufferToInfeedInternal( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source); + + // On success, returns the shape that was transferred from the outfeed. 
+ StatusOr TransferBufferFromOutfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + void* destination); + TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index c13c86741cc4291d5ae76cb4b3d7913927c565ea..2e4b0a5230516b5308aeed892de9a49565a09f2e 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -35,7 +35,15 @@ StreamExecutorMemoryAllocator::Allocate(int device_ordinal, uint64 size, bool retry_on_failure) { TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - return stream_executor->AllocateArray(size); + perftools::gputools::DeviceMemoryBase result = + stream_executor->AllocateArray(size); + if (size > 0 && result == nullptr) { + return ResourceExhausted( + "Failed to allocate request for %s (%lluB) on device ordinal %d", + tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, + device_ordinal); + } + return result; } tensorflow::Status StreamExecutorMemoryAllocator::Deallocate( diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 5b296861006923f438df1ad4fb5898f82f11b9e0..0f7ab111170a3152cbe86c1a4fa8d592a14d6241 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,51 +24,29 @@ limitations under the License. 
namespace xla { Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", HloOpcodeString(opcode).c_str()); } Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", HloOpcodeString(opcode).c_str()); } void DfsHloVisitor::SetVisiting(const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visiting: "; - CHECK(NotVisited(instruction)); + DCHECK(NotVisited(instruction)); visit_state_[&instruction] = VisitState::kVisiting; } void DfsHloVisitor::SetVisited(const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; - CHECK(NotVisited(instruction) || IsVisiting(instruction)); + DCHECK(NotVisited(instruction) || IsVisiting(instruction)); visit_state_[&instruction] = VisitState::kVisited; } -bool DfsHloVisitor::IsVisiting(const HloInstruction& instruction) { - if (visit_state_.count(&instruction) == 0) { - return false; - } - return visit_state_[&instruction] == VisitState::kVisiting; -} - -bool DfsHloVisitor::DidVisit(const HloInstruction& instruction) { - if (visit_state_.count(&instruction) == 0) { - return false; - } - return visit_state_[&instruction] == VisitState::kVisited; -} - -bool DfsHloVisitor::NotVisited(const HloInstruction& instruction) { - return visit_state_.count(&instruction) == 0 || - visit_state_[&instruction] == VisitState::kNotVisited; -} - Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); } Status DfsHloVisitor::Postprocess(HloInstruction* visited) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 
78a398f8efa870fcfbda78a769b3f6878a8a429b..e6067ae9ea244e58920d0a163078ab302327b871 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -65,43 +65,37 @@ class DfsHloVisitor { // These routines are self-descriptive, see class comment for usage // information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand); - virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs); + virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode); + virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode); virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) = 0; virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) = 0; - virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(maximum, HloOpcode::kMaximum, lhs, rhs); + virtual Status HandleMaximum(HloInstruction* maximum) { + return HandleElementwiseBinary(maximum, HloOpcode::kMaximum); } - virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(minimum, HloOpcode::kMinimum, lhs, rhs); + virtual Status HandleMinimum(HloInstruction* minimum) { + return HandleElementwiseBinary(minimum, HloOpcode::kMinimum); } virtual Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) = 0; - virtual Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - return HandleElementwiseUnary(convert, HloOpcode::kConvert, operand); + virtual Status HandleConvert(HloInstruction* convert) { + return HandleElementwiseUnary(convert, HloOpcode::kConvert); } - virtual Status 
HandleCopy(HloInstruction* copy, HloInstruction* operand) { - return HandleElementwiseUnary(copy, HloOpcode::kCopy, operand); + virtual Status HandleCopy(HloInstruction* copy) { + return HandleElementwiseUnary(copy, HloOpcode::kCopy); } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(multiply, HloOpcode::kMultiply, lhs, rhs); + return HandleElementwiseBinary(multiply, HloOpcode::kMultiply); } virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) = 0; virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(power, HloOpcode::kPower, lhs, rhs); + return HandleElementwiseBinary(power, HloOpcode::kPower); } virtual Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, @@ -109,64 +103,73 @@ class DfsHloVisitor { virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(compare, opcode, lhs, rhs); + return HandleElementwiseBinary(compare, opcode); } virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(add, HloOpcode::kAdd, lhs, rhs); + return HandleElementwiseBinary(add, HloOpcode::kAdd); } virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(divide, HloOpcode::kDivide, lhs, rhs); + return HandleElementwiseBinary(divide, HloOpcode::kDivide); } virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(remainder, HloOpcode::kRemainder, lhs, rhs); + return HandleElementwiseBinary(remainder, HloOpcode::kRemainder); } virtual Status HandleSubtract(HloInstruction* subtract, 
HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(subtract, HloOpcode::kSubtract, lhs, rhs); + return HandleElementwiseBinary(subtract, HloOpcode::kSubtract); } virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { - return HandleElementwiseUnary(abs, HloOpcode::kAbs, operand); + return HandleElementwiseUnary(abs, HloOpcode::kAbs); } virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { - return HandleElementwiseUnary(sign, HloOpcode::kSign, operand); + return HandleElementwiseUnary(sign, HloOpcode::kSign); } virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { - return HandleElementwiseUnary(negate, HloOpcode::kNegate, operand); + return HandleElementwiseUnary(negate, HloOpcode::kNegate); } virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { - return HandleElementwiseUnary(exp, HloOpcode::kExp, operand); + return HandleElementwiseUnary(exp, HloOpcode::kExp); } virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { - return HandleElementwiseUnary(floor, HloOpcode::kFloor, operand); + return HandleElementwiseUnary(floor, HloOpcode::kFloor); } virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { - return HandleElementwiseUnary(ceil, HloOpcode::kCeil, operand); + return HandleElementwiseUnary(ceil, HloOpcode::kCeil); } virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { - return HandleElementwiseUnary(log, HloOpcode::kLog, operand); + return HandleElementwiseUnary(log, HloOpcode::kLog); + } + virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { + return HandleElementwiseUnary(cos, HloOpcode::kCos); + } + virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) { + return HandleElementwiseUnary(sin, HloOpcode::kSin); } virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { - return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); + 
return HandleElementwiseUnary(tanh, HloOpcode::kTanh); } virtual Status HandleIsFinite(HloInstruction* is_finite, HloInstruction* operand) { - return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand); + return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite); } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, - rhs); + return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd); } virtual Status HandleLogicalNot(HloInstruction* logical_not, HloInstruction* operand) { - return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot, operand); + return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot); } virtual Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs); + return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { + return HandleElementwiseUnary(reduce_precision, + HloOpcode::kReducePrecision); } virtual Status HandleInfeed(HloInstruction* infeed) = 0; @@ -225,6 +228,10 @@ class DfsHloVisitor { virtual Status HandleRecv(HloInstruction* recv) = 0; + virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0; + + virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstruction* root) = 0; @@ -237,6 +244,14 @@ class DfsHloVisitor { kVisited, }; + VisitState GetVisitState(const HloInstruction& instruction) { + auto it = visit_state_.find(&instruction); + if (it == visit_state_.end()) { + return kNotVisited; + } + return it->second; + } + // Sets the visitation state of the given instruction as kVisiting. 
// // Precondition: current state must be kNotVisited. @@ -248,13 +263,19 @@ class DfsHloVisitor { void SetVisited(const HloInstruction& instruction); // Returns whether the state of the given instruction is kVisiting. - bool IsVisiting(const HloInstruction& instruction); + bool IsVisiting(const HloInstruction& instruction) { + return GetVisitState(instruction) == kVisiting; + } // Returns whether the state of the given instruction is kVisited. - bool DidVisit(const HloInstruction& instruction); + bool DidVisit(const HloInstruction& instruction) { + return GetVisitState(instruction) == kVisited; + } // Returns whether the state of the given instruction is kNotVisited. - bool NotVisited(const HloInstruction& instruction); + bool NotVisited(const HloInstruction& instruction) { + return GetVisitState(instruction) == kNotVisited; + } // This method should be overridden by subclasses that wish to run some // operation on an op before its Handle* visitor method is called. @@ -279,7 +300,7 @@ class DfsHloVisitor { private: // Tracks the visitation state of each instruction. Any instructions that are - // not found from the map are considered as VisitState::kNotVisited. + // not found in the map are considered as VisitState::kNotVisited. tensorflow::gtl::FlatMap visit_state_; TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitor); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6557c3aa8e6b8356887432c6dd91d326603fc1e7..c447165ceccc3d55088cafda24d90fedea9994ae 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -41,15 +41,23 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { // Default action performed on HloInstruction. 
virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override { + Status HandleElementwiseUnary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } + + Status HandleBatchNormTraining(HloInstruction* hlo) override { + return DefaultAction(hlo); + } + + Status HandleBatchNormGrad(HloInstruction* hlo) override { + return DefaultAction(hlo); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* /*min*/, HloInstruction* /*arg*/, HloInstruction* /*max*/) override { @@ -60,12 +68,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { tensorflow::gtl::ArraySlice /*operands*/) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstruction* convert, - HloInstruction* /*operand*/) override { + Status HandleConvert(HloInstruction* convert) override { return DefaultAction(convert); } - Status HandleCopy(HloInstruction* copy, - HloInstruction* /*operand*/) override { + Status HandleCopy(HloInstruction* copy) override { return DefaultAction(copy); } Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index be4aadb6522b8d6ad9d6425df56c1746c3849f11..81092e42d5c841546ad49cbb8cfaf501fa8cd55e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -63,7 +63,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = 
op->shape().element_type(); - CHECK(primitive_util::IsIntegralType(from_type)); + CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED); if (from_type == to_type) { return operand_value; } @@ -78,7 +78,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); } - if (primitive_util::IsUnsignedIntegralType(from_type)) { + if (primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED) { return ir_builder_->CreateUIToFP( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); @@ -172,6 +173,14 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, {operand_value->getType()}, ir_builder_); + case HloOpcode::kCos: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, + {operand_value->getType()}, + ir_builder_); + case HloOpcode::kSin: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {operand_value}, + {operand_value->getType()}, + ir_builder_); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -381,6 +390,118 @@ StatusOr ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } +StatusOr ElementalIrEmitter::EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x) const { + if (hlo->operand(0)->shape().element_type() != F32) { + return Unimplemented("reduce-precision only implemented for F32"); + } + + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); + + if (hlo->mantissa_bits() < 23) { + // Last remaining mantissa bit. 
+ const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( + ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - hlo->mantissa_bits())); + llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (hlo->exponent_bits() < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + (2^(n-1)-1), and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - (2^(n-1)-1). + // + // Note that we have already checked that exponent_bits >= 1.
+ const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (hlo->exponent_bits() - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder_->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder_->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. 
+ // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder_->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); + + if (hlo->mantissa_bits() > 0) { + result = ir_builder_->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder_->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -588,20 +709,37 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)}, {param_ir_type}, ir_builder_); auto in_block = ir_builder_->GetInsertBlock(); - auto body_block = in_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_body"); - SetToFirstInsertPoint(body_block, ir_builder_); - auto out_block = body_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_out"); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. 
+ CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(), + in_block->getTerminator() == nullptr); + + llvm::BasicBlock* body_block; + llvm::BasicBlock* out_block; + + if (ir_builder_->GetInsertPoint() == in_block->end()) { + body_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_body", ir_builder_); + out_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_out", ir_builder_); + llvm::BranchInst::Create(body_block, in_block); + } else { + body_block = in_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_body"); + out_block = body_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_out"); + body_block->getTerminator()->eraseFromParent(); + } + SetToFirstInsertPoint(body_block, ir_builder_); auto random = ir_builder_->CreateAnd( ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), leading_zeros)); - llvm::ReplaceInstWithInst( - body_block->getTerminator(), - llvm::BranchInst::Create(out_block, body_block, - ir_builder_->CreateICmpULT(random, r))); + llvm::BranchInst::Create(out_block, body_block, + ir_builder_->CreateICmpULT(random, r), + body_block); SetToFirstInsertPoint(out_block, ir_builder_); return ir_builder_->CreateAdd( p, ir_builder_->CreateSelect( @@ -647,12 +785,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: case HloOpcode::kSign: + case HloOpcode::kSin: case HloOpcode::kTanh: case HloOpcode::kLogicalNot: return [this, hlo, &operand_to_generator]( @@ -720,6 +860,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, 2))); return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); }; + case HloOpcode::kReducePrecision: + return [this, hlo, 
&operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + return EmitReducePrecision(hlo, operand_value); + }; case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { @@ -805,23 +953,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - IrArray::Index sliced_index(index.size()); - for (int i = 0; i < index.size(); ++i) { - int64 stride = hlo->slice_stride(i); - if (stride != 1) { - sliced_index[i] = ir_builder_->CreateAdd( - ir_builder_->CreateMul( - index[i], llvm::ConstantInt::get(index[i]->getType(), - stride)), - llvm::ConstantInt::get(index[i]->getType(), - hlo->slice_starts(i))); - } else { - sliced_index[i] = ir_builder_->CreateAdd( - index[i], - llvm::ConstantInt::get(index[i]->getType(), - hlo->slice_starts(i))); - } - } + IrArray::Index sliced_index = index.SourceIndexOfSlice( + /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(), + /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; case HloOpcode::kDynamicSlice: diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 2576d3823e06ed3050554b38766dbd6c6a48ca5c..bb9117ca61e3b6ccb7f1fcecb62b0be5f984e6d1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -84,6 +84,9 @@ class ElementalIrEmitter { virtual StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, + llvm::Value* x) const; + // A helper method for MakeElementGenerator. 
Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 3a9f8dc79ee0589f27fe5aabf9592a73f34c4a0e..cbc02b84627d992179e88c107840a14c104c01c8 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -15,45 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/regexp.h" namespace xla { -/* static */ void Executable::DumpExecutedHlo( - const HloModule& module, const string& label, - const HloExecutionProfile* profile) { - VLOG(2) << "module name = " << module.name(); - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - string generate_hlo_graph_regex; - if (!flags->xla_generate_hlo_graph.empty()) { - generate_hlo_graph_regex = flags->xla_generate_hlo_graph; - } else { - generate_hlo_graph_regex = - module.config().debug_options().xla_generate_hlo_graph(); - } - if (!generate_hlo_graph_regex.empty() && - RE2::PartialMatch(module.name(), generate_hlo_graph_regex)) { - hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, - flags->xla_hlo_graph_addresses, - flags->xla_hlo_graph_layout, profile); - } - if (!flags->xla_log_hlo_text.empty() && - RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) { - LOG(INFO) << "HLO for module " << module.name(); - LOG(INFO) << "Label: " << label; - XLA_LOG_LINES(2, module.ToString()); - } - if 
(!flags->xla_dump_hlo_text_to.empty()) { - hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to); - } -} - StatusOr> Executable::ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, @@ -87,8 +57,8 @@ Executable::ExecuteOnStreams( Status Executable::DumpSessionModule() { TF_RET_CHECK(dumping()); - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - const string& directory_path = flags->xla_dump_executions_to; + const string& directory_path = + module_config().debug_options().xla_dump_executions_to(); VersionedComputationHandle versioned_handle = entry_computation_handle(); // This filename does not include the version number because the computation // is only ever executed at one version. diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 291916cd9f7acb0c136dc0834b28f57a83736ec6..5388c9efa4b795b3861586030c407b6864b9382e 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -19,11 +19,12 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/session.pb.h" @@ -49,10 +50,6 @@ class Executable { shape_size_function_(std::move(shape_size_function)) {} virtual ~Executable() {} - // Dumps the executed HLO according to service-associated flags. 
- static void DumpExecutedHlo(const HloModule& module, const string& label, - const HloExecutionProfile* profile); - // Enqueues the compilation result on the provided stream, passing the given // arguments. This call is blocking and returns after the execution is done. // @@ -110,6 +107,14 @@ class Executable { return execution_profile_; } + // Returns Status::ok() if the two executables are equal to each other. + // + // An error status is returned otherwise. + virtual const Status EqualOrFail(const Executable& executable) { + return Unimplemented( + "Equality test on this executable is not implemented."); + } + // Returns whether this executable was compiled with HLO profilings support // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. @@ -191,10 +196,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( // If the profiling flag isn't enabled, we pass nullptr as the profile to // indicate profiling is not requested. HloExecutionProfile hlo_execution_profile; - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); HloExecutionProfile* profile_ptr = - flags->xla_hlo_profile && hlo_profiling_enabled() ? &hlo_execution_profile - : nullptr; + module_config().debug_options().xla_hlo_profile() && + hlo_profiling_enabled() + ? 
&hlo_execution_profile + : nullptr; auto return_value = ExecuteOnStream(run_options, arguments, profile_ptr); @@ -240,7 +246,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( } } } - DumpExecutedHlo(module(), "Service::Execute", profile_ptr); + hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", + profile_ptr); } return return_value; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index bb4712c86f6d649a9ec8f1450d90735de9ec43c3..12a6794ac177deb54dd66822a5f830ff213c7b40 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -139,7 +139,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); @@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction( 
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, param0, false_constant)); @@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { { HloComputation::Builder builder(TestName() + ".entry"); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateWhile( ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, false_constant)); @@ -182,7 +182,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); @@ -211,7 +211,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(7, module->computations().size()); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index eb8b93330fbc7b786c66a07f8009b4676358421b..69195c45ed33bbb689a0633471686a03bb6d2654 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -82,13 +82,12 @@ Status GenericTransferManager::TransferLiteralFromDevice( } *literal->mutable_shape() = device_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); + literal->Reserve(ShapeUtil::ElementsIn(device_shape)); 
TF_RETURN_IF_ERROR(TransferBufferFromDevice( executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/LiteralUtil::MutableInternalData(literal))); + /*destination=*/literal->MutableInternalData())); if (!ShapeUtil::Equal(literal_shape, device_shape)) { - literal->Swap( - LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + literal->Swap(literal->Relayout(literal_shape.layout()).get()); } TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); return Status::OK(); @@ -152,27 +151,34 @@ Status GenericTransferManager::TransferLiteralToDevice( tuple_elements_on_device.data(), destination); } - return TransferBufferToDevice( - executor, /*size=*/GetByteSizeRequirement(shape), - /*source=*/LiteralUtil::InternalData(literal), destination); + return TransferBufferToDevice(executor, + /*size=*/GetByteSizeRequirement(shape), + /*source=*/literal.InternalData(), destination); } Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { - return Unimplemented("Infeed is not supported on GPU (b/30467474)"); + return Unimplemented("Generic transfer to Infeed"); +} + +Status GenericTransferManager::TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) { + return Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) { - return Unimplemented("Outfeed is not supported on CPU/GPU (b/30467474)"); + return Unimplemented( + "Outfeed is not supported on this platform (b/30467474)"); } Status GenericTransferManager::ResetDevices( tensorflow::gtl::ArraySlice - executors) { + /*executors*/) { return Unimplemented( - "Device reset is not yet supported on CPU and GPU (b/30481585)"); + "Device reset is not yet supported on this platform (b/30481585)"); } int64 
GenericTransferManager::GetByteSizeRequirement(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 2fbdb94f06f1b12763571dc2aa9b0d770f420406..48c061d28e5967f903e9ea665fdaeb02fab7e02e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -54,6 +54,8 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 869869341179822aa8d9e9675211be92f733077d..cdd7c8187c94231c3889dc9135030268a861b3da 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -68,8 +68,8 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:stream_assignment_flags", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", ], ) @@ -253,7 +253,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:convolution_thunk_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", @@ -267,7 +266,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", 
"//tensorflow/core/platform/default/build_config:cudnn_plugin", - "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep ], ) @@ -376,7 +375,6 @@ cc_test( ":fusion_merger", ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -418,7 +416,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:gpu_compiler_flags", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -435,6 +432,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_util", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", @@ -500,8 +498,10 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_ordering", + "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service:hlo_scheduling", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 9a0b14eb7332358d0e68e6a40b47c94b88666eb6..20e0d8eb785daa07b3fcc5339efe950aac0dacad 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ 
b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -287,10 +286,7 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( const ConvolutionDescriptor& convolution_descriptor, const BufferAllocations& buffer_allocations, se::Stream* stream) { // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - legacy_flags::ConvolutionThunkFlags* flags = - legacy_flags::GetConvolutionThunkFlags(); - if (flags->xla_gpu_autotune_convolution_algorithm && - best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) { + if (best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) { // Auto-tuning either is disabled or only happens in the first run of this // function. VLOG(2) << "Profiling for best convolution algorithm used for " diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index aaf72935e61ee8b8da7df410ba3aaed63800cfd9..91d6df299da2686d6d836445d391c4b0eaf4ed00 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -81,9 +81,8 @@ class ConvolutionThunk : public Thunk { ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - // Does the convolution for the thunk on "stream". If the - // xla_gpu_autotune_convolution_algorithm is turned on, auto-tuning happens on - // the first run of this function. + // Does the convolution for the thunk on "stream". Auto-tuning happens on the + // first run of this function. 
tensorflow::Status ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2987c8913d7cdd93d57bfcca40d6c56ae4dd30f0..c03213ab6d61df56dc3c826aac90271075e6db4a 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -55,7 +55,7 @@ using tensorflow::strings::StrAppend; // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { return operand->opcode() == HloOpcode::kConstant && - LiteralUtil::IsAllFloat(operand->literal(), value); + operand->literal().IsAllFloat(value); } GpuElementalIrEmitter::GpuElementalIrEmitter( @@ -211,6 +211,12 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( case HloOpcode::kLog: return EmitLibdeviceMathCall("__nv_log", {operand_value}, {input_type}, output_type); + case HloOpcode::kCos: + return EmitLibdeviceMathCall("__nv_cos", {operand_value}, {input_type}, + output_type); + case HloOpcode::kSin: + return EmitLibdeviceMathCall("__nv_sin", {operand_value}, {input_type}, + output_type); case HloOpcode::kTanh: return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index afb78b8300b457ba9384bd66f789d333630b51e4..a9ef204b46facafabcf16d1d38d69c14e6aab497 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -98,7 +98,13 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) { // Calculate total bytes transferred in/out. double bytes = CalculateBytesReadByFusionInstruction(fusion); // Add bytes written to root instructions buffer. 
- bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + if (fusion->IsMultiOutputFusion()) { + for (auto& operand : fusion->fused_expression_root()->operands()) { + bytes += ShapeUtil::ByteSizeOf(operand->shape()); + } + } else { + bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + } // Calculate flops for all fused instructions. Use a null shape size function // because we don't care about bytes accessed by the ops. HloCostAnalysis analysis([](const Shape& shape) { return 0; }); @@ -112,8 +118,15 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) { double GetCurrentBytesTransferred(HloInstruction* fusion) { CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); const double bytes_read = CalculateBytesReadByFusionInstruction(fusion); - const double bytes_written = - ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + double bytes_written = 0; + if (fusion->IsMultiOutputFusion()) { + for (auto& operand : fusion->fused_expression_root()->operands()) { + bytes_written += ShapeUtil::ByteSizeOf(operand->shape()); + } + } else { + bytes_written = + ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + } // Current bytes transferred (ignoring non 'fusion' user operands) is bytes // read and written by 'fusion', plus reads of size 'bytes_written' for each // user. @@ -198,6 +211,12 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++num_fail_not_loop_fusion_; return Status::OK(); } + + // Skip multiple output fusion. It's not yet supported. + if (fusion->IsMultiOutputFusion()) { + ++num_fail_not_loop_fusion_; + return Status::OK(); + } // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. 
@@ -274,12 +293,19 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { StatusOr FusionMerger::Run(HloModule* module) { bool changed = false; VLOG(2) << "FusionMerger for module: " << module->name(); + std::vector computations; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (auto& computation : computations) { VLOG(1) << "Before running FusionInstructionMerger for computation: " << computation->name(); XLA_VLOG_LINES(3, computation->ToString()); - FusionInstructionMerger fusion_merger(computation.get()); + FusionInstructionMerger fusion_merger(computation); TF_RETURN_IF_ERROR(fusion_merger.Run()); changed |= fusion_merger.changed(); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 8afc32dea97ea00442d2f094c8d6de0b510482fd..242c32936d31d0cb578825cade5f35979077a44e 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -59,7 +59,7 @@ class FusionMergerTest : public HloTestBase { // Create const vector of ones to be used in element-wise computations. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create simple fusable computation for tuple element 0 (wont get merged). auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -138,7 +138,7 @@ class FusionMergerTest : public HloTestBase { // Create two sub-computations, both of which are users of 'mul0'. 
auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // First sub-computation: out0 = Mul(Add(mul0, one_vec), one_vec) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -209,7 +209,7 @@ class FusionMergerTest : public HloTestBase { // Create two fusable sub-computations which are dependent on shared // computation 'reduce_out'. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // First sub-computation: out0 = Mul(Add(reduce_out, one_vec), one_vec) auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 86137a569f9b199782462582ba11683ff9884d7b..031ecbd3aedfb3e531d8e10c9a7b381bb97037e3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -23,7 +23,6 @@ limitations under the License. #include "external/llvm/include/llvm/IR/DiagnosticPrinter.h" #include "external/llvm/include/llvm/IR/LLVMContext.h" #include "external/llvm/include/llvm/IR/Module.h" -#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -57,6 +56,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -82,7 +82,7 @@ const char* kTargetTriple = "nvptx64-nvidia-cuda"; // The data layout of the emitted module. Copied from computeDataLayout in // NVPTXTargetMachine.cpp. -const char* kDataLayout = "e-i64:64-v16:16-v32:32-n16:32:64"; +const char* kDataLayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; // Any address of a variable residing in global memory or returned by one of the // memory allocation routines from the driver or runtime API is always aligned @@ -95,11 +95,9 @@ constexpr int64 kMemoryAlignment = 256; // called in GpuCompiler's constructor, so can't return an error. But // GpuCompiler::Compile will return an error when the wanted libdevice file // doesn't exist in the folder this function returns. -string GetLibdeviceDir() { +string GetLibdeviceDir(const HloModuleConfig& config) { std::vector potential_libdevice_dirs; - // Flag xla_cuda_data_dir specified by the user. - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - const string datadir = flags->xla_cuda_data_dir; + const string datadir = config.debug_options().xla_gpu_cuda_data_dir(); if (!datadir.empty()) { potential_libdevice_dirs.push_back(datadir); } @@ -122,14 +120,16 @@ string GetLibdeviceDir() { // Runs optimization passes on the given HLO module. 
tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - const Compiler::HloDumper& dump_hlo, const se::DeviceDescription& device_desc) { { - HloPassPipeline pipeline("optimization", dump_hlo); + HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); + ReducePrecisionInsertion::AddPasses( + &pipeline, hlo_module->config().debug_options(), + HloReducePrecisionOptions::BEFORE_OP_FUSION); { - auto& pass = pipeline.AddPass>( - "simplification", dump_hlo); + auto& pass = + pipeline.AddPass>("simplification"); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -149,24 +149,37 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { - HloPassFix fusion("fusion", dump_hlo); + HloPassFix fusion("fusion"); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); - return fusion.Run(hlo_module).status(); + TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); + + HloPassPipeline reduce_pipeline("reduce-precision"); + ReducePrecisionInsertion::AddPasses( + &reduce_pipeline, hlo_module->config().debug_options(), + HloReducePrecisionOptions::AFTER_OP_FUSION); + StatusOr reduce_result = reduce_pipeline.Run(hlo_module); + TF_RETURN_IF_ERROR(reduce_result.status()); + + if (reduce_result.ValueOrDie()) { + // Do another fusion pass, with the expectation that we may be able to + // fuse the new ReducePrecision operations. + TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); + } } + return tensorflow::Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. 
-tensorflow::Status PrepareHloModuleForIrEmitting( - const Compiler::HloDumper& dump_hlo, HloModule* hlo_module) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. - HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo); + HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass( @@ -230,17 +243,15 @@ void DumpPtxasInfo(const string& ptx) { } // namespace GpuCompiler::GpuCompiler() - : libdevice_dir_(GetLibdeviceDir()), - pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} StatusOr> GpuCompiler::Compile( - std::unique_ptr module, HloDumper dump_hlo, - se::StreamExecutor* stream_exec) { + std::unique_ptr module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), dump_hlo, - stream_exec->GetDeviceDescription())); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, module.get())); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec->GetDeviceDescription())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -271,13 +282,16 @@ StatusOr> GpuCompiler::Compile( TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), [](LogicalBuffer::Color) { + return kMemoryAlignment; + })); - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - if 
(!flags->xla_gpu_dump_debug_json_to.empty()) { + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_gpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), @@ -292,7 +306,9 @@ StatusOr> GpuCompiler::Compile( entry_computation->root_instruction()->Accept(&ir_emitter)); string ir_module_string_before_opt; - if (VLOG_IS_ON(2) || flags->xla_gpu_embed_ir) { + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + if (VLOG_IS_ON(2) || embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); VLOG(2) << "LLVM module before optimizations:"; XLA_VLOG_LINES(2, ir_module_string_before_opt); @@ -313,6 +329,10 @@ StatusOr> GpuCompiler::Compile( cc_major = 2; cc_minor = 0; } + if (libdevice_dir_.empty()) { + // Compute libdevice_dir_ just once and cache it in this member. 
+ libdevice_dir_ = GetLibdeviceDir(module->config()); + } TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, module->config(), libdevice_dir_)); @@ -333,7 +353,7 @@ StatusOr> GpuCompiler::Compile( auto* gpu_executable = new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), std::move(buffer_assignment), ShapeSizeBytesFunction()); - if (flags->xla_gpu_embed_ir) { + if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); } @@ -341,16 +361,15 @@ StatusOr> GpuCompiler::Compile( } StatusOr>> GpuCompiler::Compile( - std::vector> modules, HloDumper dump_hlos, + std::vector> modules, std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } StatusOr>> -GpuCompiler::CompileAheadOfTime( - std::vector> module, - HloDumper dump_hlo, const AotCompilationOptions& options) { +GpuCompiler::CompileAheadOfTime(std::vector> module, + const AotCompilationOptions& options) { return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index da52f5ab1f8e5bf8c2fa3c33948ccf8a0f647f7a..b87555b931f1d73de8bcaf84aea80305c9d585bf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -41,17 +41,16 @@ class GpuCompiler : public Compiler { ~GpuCompiler() override {} StatusOr> Compile( - std::unique_ptr module, HloDumper dump_hlo, + std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> modules, HloDumper dump_hlo, + std::vector> modules, std::vector stream_exec) override; StatusOr>> - CompileAheadOfTime( - std::vector> module, - HloDumper dump_hlo, AotCompilationOptions const& options) override; + CompileAheadOfTime(std::vector> module, + 
AotCompilationOptions const& options) override; perftools::gputools::Platform::Id PlatformId() const override; @@ -65,7 +64,7 @@ class GpuCompiler : public Compiler { private: // The parent directory of libdevice IR libraries. - const string libdevice_dir_; + string libdevice_dir_; // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler // to own them because they need to be alive across the life span of the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index e1a55118fc7a962cbc77b8214f01451e6f155ca0..8558e150e06e31fc36c60cf8564f0d22cba020e8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -80,6 +80,11 @@ class GpuExecutable : public Executable { tensorflow::gtl::ArraySlice arguments) override; + const Status EqualOrFail(const Executable& executable) { + // TODO(b/62952745) Implement equality test on GPU executable. + return Unimplemented("Equality test on GPU executable is not implemented."); + } + private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index d16a1d4ee5be00e685fc181f19c1a3cfda253f6a..81e905a06665436875b17991a8635e7bb31600de 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -67,38 +69,38 @@ GpuHloOrdering::GpuHloOrdering( // waits for its operands before executing. // // The predecessor map is built incrementally, in thunk launch order. We - // record the instructions already visited per stream in - // 'instructions_per_stream'. This lets us quickly determine the same-stream - // predecessors of each instruction. To capture cross-stream dependency edges, - // we use the predecessor map to insert each operand as well as its transitive - // closure of dependencies. - - // Compute the set of all instructions we will want to set reachability on - auto predecessor_map = MakeUnique( + // record the most-recently seen instructions per stream in + // 'last_instruction_per_stream'. This lets us quickly determine the + // same-stream predecessors of each instruction. + + // Compute the set of all instructions we will want to set reachability on. + auto predecessor_map = MakeUnique( module->entry_computation()->MakeInstructionPostOrder()); - std::vector> instructions_per_stream( - stream_assignment.StreamCount()); + // The most recently visited instruction per stream. + std::vector last_instruction_per_stream( + stream_assignment.StreamCount(), nullptr); for (const HloInstruction* hlo : thunk_launch_order) { + predecessor_map->SetReachable(hlo, hlo); if (stream_assignment.HasStreamAssigned(*hlo)) { + // Gather all instruction which are immediate predecessors of 'hlo' in the + // reachability graph. 
+ std::vector immediate_preds; + immediate_preds.insert(immediate_preds.end(), hlo->operands().begin(), + hlo->operands().end()); + immediate_preds.insert(immediate_preds.end(), + hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + // All ops already queued on the same instruction stream, and their - // transitive predecessors, are predecessors. Since the relation is - // transitive, we just set the transitive closure of the previous op. + // transitive predecessors, are predecessors. const int stream_no = stream_assignment.StreamNumberForHlo(*hlo); - std::vector* instructions = - &instructions_per_stream[stream_no]; - if (!instructions->empty()) { - const HloInstruction* back = instructions->back(); - predecessor_map->SetReachableAndTransitiveClosure(hlo, back); - } - // All operands and their transitive predecessors are predecessors. Each - // operand must already exist in 'predecessor_map', since we're iterating - // in thunk launch order. - for (const HloInstruction* operand : hlo->operands()) { - predecessor_map->SetReachableAndTransitiveClosure(hlo, operand); + if (last_instruction_per_stream[stream_no] != nullptr) { + immediate_preds.push_back(last_instruction_per_stream[stream_no]); } - instructions->push_back(hlo); + predecessor_map->SetReachabilityToUnion(immediate_preds, hlo); + last_instruction_per_stream[stream_no] = hlo; } else { // Only parameters and constants don't have an assigned stream, since they // don't require a thunk. These ops don't have any predecessors. @@ -107,21 +109,21 @@ GpuHloOrdering::GpuHloOrdering( CHECK_EQ(hlo->operand_count(), 0); } } - strict_predecessors_.emplace(module->entry_computation(), - std::move(predecessor_map)); + predecessors_.emplace(module->entry_computation(), + std::move(predecessor_map)); - // The ordering of instructions in subcomputations is based solely on data - // dependencies. I.e. the strict predecessors of each subcomputation - // instruction is its transitive operands. 
+ // The ordering of instructions in subcomputations is based solely on control + // and data dependencies. // // TODO(toddw): Each subcomputation is actually emitted as a function in DFS // postorder, so we can do better and establish the total order here. We don't // do that yet since it's hard to ensure that the order here is the order used // by IrEmitterNested. And mismatched ordering bugs would be hard to find. for (auto& computation : module->computations()) { - if (computation.get() != module->entry_computation()) { - strict_predecessors_.emplace(computation.get(), - computation->ComputeTransitiveOperands()); + if (computation.get() != module->entry_computation() && + !computation->IsFusionComputation()) { + predecessors_.emplace(computation.get(), + computation->ComputeReachability()); } } } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h index 773973010a46bb4a2af1f536c43201ba8c0be5d8..1ce7a48ac8fcbbad0b3697845681582fe806b322 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h @@ -19,9 +19,9 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 1a61eec353740202065c1ce98e8c91274facfd19..a04214930dfc95b82ca4c702d12648381a4c8135 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -86,23 +86,35 @@ void HloToIrBindings::EmitBasePointersForHlos( continue; } - // A non-IO HLO with a buffer is bound to - // (1) an alloca if it is thread-local, or - // (2) an internal pointer in temp_buffer_base according to its offset. - const BufferAllocation::Slice slice = - buffer_assignment_->GetUniqueTopLevelSlice(non_io_hlo) - .ConsumeValueOrDie(); - if (slice.allocation()->is_thread_local()) { - llvm::Type* pointee_type = - llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); - BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type)); - } else { - const int64 offset = slice.offset(); - CHECK_NE(nullptr, temp_buffer_base_); - BindHloToIrValue(*non_io_hlo, - ir_builder_->CreateInBoundsGEP( - temp_buffer_base_, ir_builder_->getInt64(offset))); - } + ShapeUtil::ForEachSubshape( + non_io_hlo->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + // A non-IO HLO with a buffer is bound to + // (1) an alloca if it is thread-local, or + // (2) an internal pointer in temp_buffer_base according to its + // offset. 
+ auto slice_result = + buffer_assignment_->GetUniqueSlice(non_io_hlo, index); + if (!slice_result.ok()) { + return; + } + const BufferAllocation::Slice slice = + slice_result.ConsumeValueOrDie(); + if (slice.allocation()->is_thread_local()) { + llvm::Type* pointee_type = + llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); + BindHloToIrValue(*non_io_hlo, + ir_builder_->CreateAlloca(pointee_type), index); + } else { + const int64 offset = slice.offset(); + CHECK_NE(nullptr, temp_buffer_base_); + BindHloToIrValue( + *non_io_hlo, + ir_builder_->CreateInBoundsGEP(temp_buffer_base_, + ir_builder_->getInt64(offset)), + index); + } + }); } } @@ -112,7 +124,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), base_ptr), ir_builder_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, @@ -120,8 +132,10 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, + const ShapeIndex& shape_index, llvm::Value* ir_value) { - llvm::Type* pointee_type = llvm_ir::ShapeToIrType(hlo.shape(), ir_builder_); + llvm::Type* pointee_type = llvm_ir::ShapeToIrType( + ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_); llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; @@ -139,13 +153,24 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, } void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, - llvm::Value* ir_value) { + llvm::Value* ir_value, + const ShapeIndex& shape_index) { VLOG(2) << "Binding " << hlo.ToString(); - InsertOrDie(&base_ptrs_, &hlo, GetTypedIrValue(hlo, ir_value)); + + const 
Shape& hlo_shape = hlo.shape(); + llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value); + + if (!BoundToIrValue(hlo)) { + // Set the root of ShapeTree first before assigning the element ir value. + InsertOrDie(&base_ptrs_, &hlo, ShapeTree(hlo_shape, nullptr)); + } + *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value; } -llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo) { - llvm_ir::IrArray ir_array(GetBasePointer(hlo), hlo.shape()); +llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, + const ShapeIndex& shape_index) { + llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), + ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); return ir_array; } @@ -154,7 +179,7 @@ void HloToIrBindings::UnbindAllLocalIrValues() { std::vector hlos_to_unbind; for (auto& key_value : base_ptrs_) { if (!llvm::isa( - key_value.second->stripPointerCasts())) { + (key_value.second.element({}))->stripPointerCasts())) { hlos_to_unbind.push_back(key_value.first); } } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 5be2150801fbd2a3a624d9c87513d5cee7288bbd..2c59886e9ae410b6a6a1dd9973c75c061c8db808 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -48,7 +48,8 @@ class HloToIrBindings { tensorflow::gtl::ArraySlice non_io_hlos); // Rebinds the given HLO to the LLVM IR value that represent its address. - void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value); + void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, + const ShapeIndex& shape_index = {}); // Unbinds all IR values that's defined in an LLVM function, e.g., function // arguments and stack variables. Global variables will be kept in bindings_. 
@@ -64,15 +65,18 @@ class HloToIrBindings { llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } - // A helper method that returns the base pointer of the IrArray for "inst". - llvm::Value* GetBasePointer(const HloInstruction& hlo) const { + // A helper method that returns the base pointer of the IrArray containing the + // output of "inst".at the given ShapeIndex. + llvm::Value* GetBasePointer(const HloInstruction& hlo, + const ShapeIndex& shape_index = {}) const { auto it = base_ptrs_.find(&hlo); CHECK(it != base_ptrs_.end()); - return it->second; + return it->second.element(shape_index); } // Return the underlying IrArray of the output of the given instruction. - llvm_ir::IrArray GetIrArray(const HloInstruction& hlo); + llvm_ir::IrArray GetIrArray(const HloInstruction& hlo, + const ShapeIndex& shape_index = {}); private: // Emits IR to resolve (possibly) recursive GetTupleElement instructions. @@ -81,6 +85,7 @@ class HloToIrBindings { // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. llvm::Value* GetTypedIrValue(const HloInstruction& hlo, + const ShapeIndex& shape_index, llvm::Value* ir_value); const BufferAssignment* buffer_assignment_; @@ -90,7 +95,10 @@ class HloToIrBindings { llvm::IRBuilder<>* ir_builder_; // Stores the underlying llvm::IrArray for each HloInstruction. - std::unordered_map base_ptrs_; + // For an instruction that generates multiple outputs, the root will be a + // tuple shape. The IrArray for each element output is stored in the subnode + // in the ShapeTree. + std::unordered_map> base_ptrs_; // The address of the memory block that contains all temporary buffers. 
llvm::Value* temp_buffer_base_; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 120a3f7fba2101ce64da1e8135fb5f862e603fe4..ee5b447c9cd0b1fde4d3a0943d5d4cb8cc5b3376 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/platform/logging.h" namespace se = ::perftools::gputools; @@ -22,23 +24,23 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -InfeedManager::InfeedManager() - : current_buffer_(nullptr), - host_to_device_executor_(nullptr) {} +InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {} void InfeedManager::Reset() { tensorflow::mutex_lock l(mu_); - CHECK(!current_buffer_); + CHECK(dequeued_buffer_.empty()); for (auto buffer : enqueued_buffer_) { buffer->Done(); } enqueued_buffer_.clear(); } -void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { +void InfeedManager::EnqueueBuffers(const std::vector& buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffer_.empty(); - enqueued_buffer_.push_back(buffer); + for (gpu::InfeedBuffer* b : buffers) { + enqueued_buffer_.push_back(b); + } if (was_empty) { // This has the potential to suffer from the notified thread // immediately trying and failing to acquire mu_, but seems @@ -53,18 +55,23 @@ InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { while (enqueued_buffer_.empty()) { cv_.wait(l); } - CHECK(!current_buffer_); - current_buffer_ = enqueued_buffer_.front(); + InfeedBuffer* 
current_buffer = enqueued_buffer_.front(); enqueued_buffer_.pop_front(); - return current_buffer_; + dequeued_buffer_.insert(current_buffer); + return current_buffer; } -void InfeedManager::ReleaseCurrentBuffer(se::DeviceMemoryBase* device_memory) { - tensorflow::mutex_lock l(mu_); - CHECK(current_buffer_); - CHECK(device_memory->IsSameAs(*current_buffer_->device_memory())); - current_buffer_->Done(); - current_buffer_ = nullptr; +void InfeedManager::ReleaseBuffers(const std::vector& buffers) { + { + tensorflow::mutex_lock l(mu_); + for (gpu::InfeedBuffer* b : buffers) { + CHECK(ContainsKey(dequeued_buffer_, b)); + dequeued_buffer_.erase(b); + } + } + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } } se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index 50d0ce340f3d85c2c46f111dba3e316ff0f4df1a..73d5a5ce35497f156a181371bfb97fc37a8eb09e 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -81,25 +82,19 @@ class InfeedManager { // condition is to call Reset when no computation is taking place. void Reset(); - // Adds buffer to the infeed queue. buffer->Done will be called when - // the buffer will no longer be accessed by the InfeedManager, - // either as a result of a call to Reset or because the runtime has - // dequeued and used the buffer. - void EnqueueBuffer(InfeedBuffer* buffer); + // Adds a set of buffers to the infeed queue atomically. 
buffer->Done + // will be called when the buffer will no longer be accessed by the + // InfeedManager, either as a result of a call to Reset or because the + // runtime has dequeued and used the buffer. + void EnqueueBuffers(const std::vector& buffers); // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Sets the current buffer to be - // the returned buffer. It is an error to call BlockingDequeueBuffer - // if there is an unreleased current buffer, i.e., - // ReleaseCurrentBuffer must be called between calls to - // BlockingDequeueBuffer. + // buffer at the head of the queue. Adds the current buffer to the + // to-be released set. InfeedBuffer* BlockingDequeueBuffer(); - // Releases the current buffer, which is the last buffer returned by - // BlockingDequeueBuffer and not yet released. device_memory must - // match that of the current buffer. - void ReleaseCurrentBuffer( - perftools::gputools::DeviceMemoryBase* device_memory); + // Releases a set of buffers from the to-be released set. + void ReleaseBuffers(const std::vector& buffers); // Returns a cached stream associated with an executor. Allocates a // new stream on the first invocation. On subsequent invocations, if @@ -109,18 +104,25 @@ class InfeedManager { perftools::gputools::StreamExecutor* executor); private: + // TODO(b/30467474): Revisit if this mutex becomes a point of + // contention. tensorflow::mutex mu_; + // Condition variable that is signaled every time a buffer is // enqueued to an empty queue. tensorflow::condition_variable cv_; + // InfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. std::deque enqueued_buffer_; - // If non-NULL, the buffer that is currently being processed by the + + // Buffers that are dequeued and currently being processed by the // runtime. Not owned. 
- InfeedBuffer* current_buffer_; + tensorflow::gtl::FlatSet dequeued_buffer_; + // Cached host to device stream for queuing infeed data. std::unique_ptr host_to_device_stream_; + // Executor that the host_to_device_stream belongs to. Not owned. perftools::gputools::StreamExecutor* host_to_device_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 6f144c7273e69beedeb143c395ce37414ce99139..e33e904692ca5ad41e17d2e165dbb40b6bd4aa33 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -21,31 +21,59 @@ limitations under the License. namespace xla { namespace gpu { -InfeedThunk::InfeedThunk(const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction) +InfeedThunk::InfeedThunk( + tensorflow::gtl::ArraySlice tuple_element_buffers, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* hlo_instruction) : Thunk(Kind::kInfeed, hlo_instruction), - destination_buffer_(destination_buffer), - mem_size_(mem_size) {} + tuple_element_buffers_(tuple_element_buffers.begin(), + tuple_element_buffers.end()), + destination_buffer_(destination_buffer) {} tensorflow::Status InfeedThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - perftools::gputools::DeviceMemoryBase destination_data = + + perftools::gputools::DeviceMemoryBase destination_address = buffer_allocations.GetDeviceAddress(destination_buffer_); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - CHECK_EQ(buffer->length(), mem_size_); - stream->ThenMemcpy(&destination_data, *(buffer->device_memory()), - buffer->length()); + std::vector infeed_buffers; + if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { + 
CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); + // Transfer the tuple elements first. + std::vector tuple_element_addresses; + for (BufferAllocation::Slice tuple_element_buffer : + tuple_element_buffers_) { + perftools::gputools::DeviceMemoryBase tuple_element_address = + buffer_allocations.GetDeviceAddress(tuple_element_buffer); + + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + infeed_buffers.push_back(buffer); + stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()), + buffer->length()); + tuple_element_addresses.push_back(tuple_element_address.opaque()); + } + // Transfer the tuple outer buffer. + auto host_size = tuple_element_addresses.size() * sizeof(void*); + stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), + host_size); + } else { + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + infeed_buffers.push_back(buffer); + stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), + buffer->length()); + } + if (!stream->BlockHostUntilDone()) { return InternalError("Failed to complete data transfer on stream %p", stream); } - // Since Infeeds are totally ordered, no other infeed should sneak - // in and we should be able to release the same buffer we dequeued. - infeed_manager->ReleaseCurrentBuffer(buffer->device_memory()); + + infeed_manager->ReleaseBuffers(infeed_buffers); + + VLOG(2) << "Infeeding to GPU complete"; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 0a808186c212660e4be3905456d29cb2fed0f511..371d71f9dbdd21cb5f36cc3108c8f398a4a91c29 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -35,8 +35,10 @@ class InfeedThunk : public Thunk { // infeed queue to the device buffer // `destination_buffer`. `mem_size` is the size of the data in // bytes. 
- InfeedThunk(const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction); + InfeedThunk(tensorflow::gtl::ArraySlice + tuple_element_buffers, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; @@ -46,8 +48,8 @@ class InfeedThunk : public Thunk { perftools::gputools::Stream* stream) override; private: + const std::vector tuple_element_buffers_; const BufferAllocation::Slice destination_buffer_; - const uint64 mem_size_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 7d5b6ed5cfabcd429cc25f63b8fa14e2e20e387f..80f91e5daed30567ff66476ff9066dc36b01ee3c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -202,18 +202,22 @@ bool IrEmitter::MaybeEmitSpecialAtomicOperation( // NVPTX supports atomicMax and atomicMin on only integer types. if (root_opcode == HloOpcode::kMaximum && primitive_util::IsIntegralType(element_type)) { - // min(integral, integral) - ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Max, output_address, - source, + // max(integral, integral) + auto opcode = primitive_util::IsSignedIntegralType(element_type) + ? llvm::AtomicRMWInst::Max + : llvm::AtomicRMWInst::UMax; + ir_builder_.CreateAtomicRMW(opcode, output_address, source, llvm::AtomicOrdering::SequentiallyConsistent); return true; } if (root_opcode == HloOpcode::kMinimum && primitive_util::IsIntegralType(element_type)) { - // max(integral, integral) - ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Min, output_address, - source, + // min(integral, integral) + auto opcode = primitive_util::IsSignedIntegralType(element_type) + ? 
llvm::AtomicRMWInst::Min + : llvm::AtomicRMWInst::UMin; + ir_builder_.CreateAtomicRMW(opcode, output_address, source, llvm::AtomicOrdering::SequentiallyConsistent); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 607a366ac67d98d11c5141b390420aef00539dcd..718e27101e0dc2bfb1338f17979d452b08a2a376 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -118,8 +118,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { IrEmitterContext* ir_emitter_context, bool is_nested); // A convenient helper for calling HloToIrBindings::GetIrArray. - llvm_ir::IrArray GetIrArray(const HloInstruction& inst) { - return bindings_.GetIrArray(inst); + llvm_ir::IrArray GetIrArray(const HloInstruction& inst, + const ShapeIndex& shape_index = {}) { + return bindings_.GetIrArray(inst, shape_index); } // A convenient helper for calling HloToIrBindings::GetBasePointer. llvm::Value* GetBasePointer(const HloInstruction& inst) const { @@ -231,7 +232,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. 
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 5fa2bfdd7e4301144054e0d4f41d1161e798176b..484de369675fb0188754d4bc2d187cbc6c92259b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -722,8 +722,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace -Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (ImplementedAsMemcpy(*copy)) { thunk_sequence_->emplace_back(BuildCopyThunk(copy)); return Status::OK(); @@ -731,7 +730,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, bool is_transpose_021; Shape reduced_input_shape, reduced_output_shape; std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(operand->shape(), copy->shape()); + IsTranspose021(copy->operand(0)->shape(), copy->shape()); if (is_transpose_021 && reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { @@ -739,7 +738,8 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, VLOG(3) << "Emitting tiled 0-2-1 transposition"; constexpr int64 tile_size = 32; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_), + GetIrArray(*(copy->operand(0))) + .CastToShape(reduced_input_shape, &ir_builder_), GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), tile_size, 
&ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(), @@ -747,7 +747,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, return Status::OK(); } - return IrEmitter::HandleCopy(copy, operand); + return IrEmitter::HandleCopy(copy); } Status IrEmitterUnnested::EmitColumnReduction( @@ -1648,7 +1648,7 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); return MakeUnique( - /*source_address=*/LiteralUtil::InternalData(operand->literal()), + /*source_address=*/operand->literal().InternalData(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ llvm_ir::ByteSizeOf(operand->shape(), @@ -1659,12 +1659,18 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); + + std::vector tuple_element_buffers; + for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { + BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, {i}) + .ConsumeValueOrDie(); + tuple_element_buffers.push_back(buffer); + } + return MakeUnique( - /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ - llvm_ir::ByteSizeOf(inst->shape(), - ir_emitter_context_->llvm_module()->getDataLayout()), - inst); + tuple_element_buffers, + /*destination_buffer=*/GetAllocationSlice(*inst), inst); } std::unique_ptr IrEmitterUnnested::BuildGemmThunk( @@ -1880,15 +1886,38 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + const Shape& element_shape = hlo.IsMultiOutputFusion() + ? 
ShapeUtil::GetSubshape(hlo.shape(), {0}) + : hlo.shape(); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - hlo.shape(), ir_emitter_context_->device_description()); + element_shape, ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); - // Otherwise, emit a parallel loop that computes the partition that each - // thread is in charge of. - return ParallelLoopEmitter(element_generator, GetIrArray(hlo), - launch_dimensions, &ir_builder_) - .EmitLoop(); + if (!hlo.IsMultiOutputFusion()) { + return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + launch_dimensions, &ir_builder_) + .EmitLoop(); + } + + // For multiple outputs fusion, we need to emit each operand and the root. + std::vector output_arrays; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { + output_arrays.push_back(GetIrArray(hlo, {i})); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, + launch_dimensions, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); + // const HloInstruction* root = hlo.fused_expression_root(); + llvm_ir::EmitTuple( + GetIrArray(*hlo.fused_expression_root()->fusion_instruction()), + tuple_operand_ptrs, &ir_builder_); + return Status::OK(); } Status IrEmitterUnnested::EmitTargetElementLoop( diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 724549c0c4ef46e7526953f41439ea8eff71a779..1d1e5bee542c1c682fa74121934348e7e7a1b026 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -28,10 +28,10 @@ cc_library( "utils.h", ], deps = [ + 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:gpu_backend_lib_flags", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index e03571a9672df62593318766fcecf414e0899ea1..881522a0298a8c8cd45d03a4863ad5e995bd4b13 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" @@ -134,13 +133,8 @@ static string GetSmName(std::pair compute_capability) { // from the input filename. string MakeNameForTempProduct(const std::string& input_filename, tensorflow::StringPiece extension) { - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - return tensorflow::io::JoinPath( - flags->dump_temp_products_to, - ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), - extension)); + return ReplaceFilenameExtension( + tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -177,20 +171,16 @@ std::unique_ptr GetTargetMachine( .xla_enable_fast_math(), &target_options); - // Enable FMA synthesis if desired. 
- legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (flags->fma) { - target_options.AllowFPOpFusion = FPOpFusion::Fast; - } + // Enable FMA synthesis. + target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. - target_options.MCOptions.AsmVerbose = flags->verbose_ptx_asm; + target_options.MCOptions.AsmVerbose = false; // The selection of codegen optimization level is copied from function // GetCodeGenOptLevel in //external/llvm/tools/opt/opt.cpp. CodeGenOpt::Level codegen_opt_level; - switch (flags->opt_level) { + switch (hlo_module_config.debug_options().xla_backend_optimization_level()) { case 1: codegen_opt_level = CodeGenOpt::Less; break; @@ -262,12 +252,10 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // The extension is stripped by IrDumpingPassManager, so we need to // get creative to add a suffix. string module_id(llvm_ir::AsString(module->getModuleIdentifier())); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); IrDumpingPassManager codegen_passes( ReplaceFilenameExtension(tensorflow::io::Basename(module_id), "-nvptx.dummy"), - flags->dump_temp_products_to, flags->dump_ir_before_passes); + "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -345,36 +333,19 @@ StatusOr CompileModuleToPtx(llvm::Module* module, TF_RETURN_IF_ERROR( LinkLibdeviceIfNecessary(module, compute_capability, libdevice_dir_path)); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (!flags->dump_temp_products_to.empty()) { - string linked_filename = - MakeNameForTempProduct(module->getModuleIdentifier(), "linked.bc"); - LOG(INFO) << "dumping bitcode after linking libdevice to: " - << linked_filename; - EmitBitcodeToFile(*module, linked_filename); - } - // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass // can access it. 
- module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", flags->ftz); + module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", + hlo_module_config.debug_options().xla_gpu_ftz()); // If ftz is enabled, set it as an attribute on every function in the module. - if (flags->ftz) { + if (hlo_module_config.debug_options().xla_gpu_ftz()) { for (llvm::Function& fn : *module) { fn.addFnAttr("nvptx-f32ftz", "true"); } } - // Run IR-level optimizations. - if (flags->dump_ir_before_passes && flags->dump_temp_products_to.empty()) { - LOG(FATAL) << "--dump_ir_before_passes must be specified with " - "--dump_temp_products_to"; - } - - IrDumpingPassManager module_passes(module->getModuleIdentifier(), - flags->dump_temp_products_to, - flags->dump_ir_before_passes); + IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false); // Add an appropriate TargetLibraryInfo pass for the module's triple. llvm::TargetLibraryInfoWrapperPass* tliwp = @@ -406,8 +377,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // too. llvm::legacy::FunctionPassManager function_passes(module); - AddOptimizationPasses(flags->opt_level, /*size_level=*/0, - target_machine.get(), &module_passes, &function_passes); + int32 opt_level = + hlo_module_config.debug_options().xla_backend_optimization_level(); + + CHECK_GE(opt_level, 2) + << "The XLA GPU backend doesn't support unoptimized code generation"; + + AddOptimizationPasses(opt_level, + /*size_level=*/0, target_machine.get(), &module_passes, + &function_passes); + // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. // TODO(jingyue): SROA may further expose more optimization opportunities, such @@ -415,7 +394,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // the inlining cost of a function). For now, running SROA already emits good // enough code for the evaluated benchmarks. 
We may want to run more // optimizations later. - if (flags->opt_level > 0) { + if (opt_level > 0) { // LLVM's optimizer turns on SROA when the optimization level is greater // than 0. We mimic this behavior here. module_passes.add(llvm::createSROAPass()); @@ -433,14 +412,6 @@ StatusOr CompileModuleToPtx(llvm::Module* module, function_passes.doFinalization(); module_passes.run(*module); - if (!flags->dump_temp_products_to.empty()) { - string optimized_filename = - MakeNameForTempProduct(module->getModuleIdentifier(), "optimized.bc"); - LOG(INFO) << "dumping bitcode after optimizations to: " - << optimized_filename; - EmitBitcodeToFile(*module, optimized_filename); - } - // Finally, produce PTX. return EmitModuleToPTX(module, target_machine.get()); } @@ -473,22 +444,6 @@ void GPUBackendInit() { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (!flags->llvm_cl_opts.empty()) { - std::vector opts = - tensorflow::str_util::Split(flags->llvm_cl_opts, ','); - FeedLLVMWithFlags(opts); - } - - if (flags->llvm_dump_passes) { - // Enable LLVM pass debugging dump. LLVM dumps this information when a pass - // manager is initialized for execution. It's done to stderr (this is - // hardcoded within LLVM to the dbgs() stream, we can't change it from the - // outside). - FeedLLVMWithFlags({"-debug-pass=Arguments"}); - } - // Initialize the NVPTX target; it's the only target we link with, so call its // specific initialization functions instead of the catch-all InitializeAll*. 
LLVMInitializeNVPTXTarget(); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index a12a9a716829fbcf5b6348037fa723d5ddcc6930..b8c61620845a1434cc79dc9a8b00f89944e2ae95 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -61,7 +61,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + MakeUnique(Literal::Zero(element_type)))); input = computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape( /*operand_shape=*/input->shape(), @@ -127,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + MakeUnique(Literal::Zero(element_type)))); return computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape( /*operand_shape=*/kernel->shape(), @@ -242,9 +242,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. 
HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape(input->shape(), padding->shape(), diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 65610b0995c512cc4a611ac650c581d0180d258d..d5543d296b3f0f6b19de90c42bea4f162057802a 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -36,6 +36,13 @@ ParallelLoopEmitter::ParallelLoopEmitter( : LoopEmitter(body_emitter, shape, ir_builder), launch_dimensions_(launch_dimensions) {} +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_arrays, ir_builder), + launch_dimensions_(launch_dimensions) {} + ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 73ca28cd842fe350ecd10885d983907e7288a350..d324a50698ea0d3e5e196347bd69c29b2ad27e3e 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -41,6 +41,12 @@ class ParallelLoopEmitter : public 
llvm_ir::LoopEmitter { const llvm_ir::IrArray& target_array, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5065e7aedd08c591f33c152c6709823948db54f0..e4cfc6999f2da04dd7e7a34d854fdb3d75b8bfc6 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" -#include "tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { namespace gpu { @@ -46,10 +46,9 @@ namespace { // Returns whether the two HLOs can run concurrently, i.e., neither is a // transitive consumer of the other. 
-bool CanRunConcurrently( - const HloInstruction& a, const HloInstruction& b, - const HloComputation::ReachabilityMap& transitive_operands) { - return !transitive_operands.IsConnected(&a, &b); +bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b, + const HloReachabilityMap& reachability) { + return !reachability.IsConnected(&a, &b); } // Returns which existing stream to assign to `hlo`, or -1 if a stream is not @@ -58,7 +57,7 @@ bool CanRunConcurrently( // are topologically before `hlo`. int ComputeStreamToAssign( const HloInstruction& hlo, const StreamAssignment& stream_assignment, - const HloComputation::ReachabilityMap& transitive_operands, + const HloReachabilityMap& reachability, const std::vector& seen_gemms) { if (hlo.opcode() == HloOpcode::kParameter || hlo.opcode() == HloOpcode::kConstant) { @@ -66,9 +65,10 @@ int ComputeStreamToAssign( return -1; } - legacy_flags::StreamAssignmentFlags* flags = - legacy_flags::GetStreamAssignmentFlags(); - if (flags->xla_gpu_disable_multi_streaming) { + if (hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_disable_multi_streaming()) { return 0; } @@ -96,7 +96,7 @@ int ComputeStreamToAssign( for (const auto* seen_gemm : seen_gemms) { int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); if (!forbidden_stream_numbers.count(stream_no) && - CanRunConcurrently(*seen_gemm, hlo, transitive_operands)) { + CanRunConcurrently(*seen_gemm, hlo, reachability)) { forbidden_stream_numbers.insert(stream_no); } } @@ -115,12 +115,12 @@ int ComputeStreamToAssign( std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = MakeUnique(); const HloComputation& computation = *module.entry_computation(); - std::unique_ptr transitive_operands = - computation.ComputeTransitiveOperands(); + std::unique_ptr reachability = + computation.ComputeReachability(); std::vector seen_gemms; for (const auto* hlo : computation.MakeInstructionPostOrder()) { int stream_no = 
ComputeStreamToAssign(*hlo, *stream_assignment, - *transitive_operands, seen_gemms); + *reachability, seen_gemms); if (stream_no != -1) { stream_assignment->AssignStreamToHlo(hlo, stream_no); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index 06b01d311dac5a6be78d7b8b16e7fcb39c189647..3034ed06b7eaff46a923b19cedb39f02d276c9f8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -37,8 +37,8 @@ namespace { // patterns to match. // // Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and HloOpcode -// of the operand. +// of type ExprTree). Operands can be added by specifying the index and +// HloOpcode of the operand. // // For example, the following computation: // @@ -197,10 +197,9 @@ class MatcherBase { return InvalidArgument("Must use S32 or S64 integral types."); } if (type == S32) { - *const_value = - static_cast(LiteralUtil::GetFirstElement(literal)); + *const_value = static_cast(literal.GetFirstElement()); } else if (type == S64) { - *const_value = LiteralUtil::GetFirstElement(literal); + *const_value = literal.GetFirstElement(); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index e82491fd6f9f1158fc5b9e5bd475ef6ff97f2a7c..51d38f84212b01c08c33f1b648c579c5672769ba 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -41,7 +41,7 @@ class WhileTransformerTest : public HloTestBase { const int64 tuple_index, const int64 limit) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); + HloInstruction::CreateConstant(Literal::CreateR0(limit))); auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto induction_variable = @@ -64,8 +64,8 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, ind_var_tuple_index)); - auto inc = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(increment))); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(increment))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(data_tuple_index). @@ -88,12 +88,10 @@ class WhileTransformerTest : public HloTestBase { const int64 ind_var_tuple_index, const int64 ind_var_init) { auto builder = HloComputation::Builder(TestName() + ".While"); - auto induction_var_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(ind_var_init))); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto induction_var_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); auto loop_state_init = ind_var_tuple_index == 0 ? 
builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc index 4b8d190a463ceb155f4fc8d3d22b47b9cbc8f23f..74f0bdb7db1847119c5bd75cc9fd9d921c6e162a 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -28,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -44,24 +44,85 @@ GpuTransferManager::GpuTransferManager() Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { const Shape& shape = literal.shape(); - VLOG(2) << "Transferring literal shape to infeed: " + VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - // TODO(b/30467474) handle tuples. 
- if (ShapeUtil::IsTuple(shape)) { - return Unimplemented("Infeed with a tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + if (!ShapeUtil::IsTuple(shape)) { + int64 size = GetByteSizeRequirement(shape); + return TransferBufferToInfeed(executor, size, literal.InternalData()); } - int64 size = GetByteSizeRequirement(shape); + if (ShapeUtil::IsNestedTuple(shape)) { + return Unimplemented( + "Infeed with a nested tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + // For a tuple, we transfer each of its elements to the device and + // enqueue the resulting destination device addresses with the + // infeed manager. + std::vector buffers; + buffers.reserve(literal.tuple_literals_size()); + auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } + }); + + for (const auto& tuple_element : literal.tuple_literals()) { + const Shape& tuple_element_shape = tuple_element.shape(); + int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); + TF_ASSIGN_OR_RETURN( + gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + tuple_element.InternalData())); + buffers.push_back(buffer); + } + + cleanup.release(); + return EnqueueBuffersToInfeed(executor, buffers); +} + +Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, + const void* source) { + TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, size, source)); + return EnqueueBuffersToInfeed(executor, {buffer}); +} + +Status GpuTransferManager::EnqueueBuffersToInfeed( + se::StreamExecutor* executor, std::vector buffers) { + gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); + se::Stream* stream = infeed_manager->GetStream(executor); + + // TODO(b/30467474): Since this stream is shared across different + // infeed requests, blocking on the stream 
might be + // heavy-handed. Figure out if finer-grained acknowledgement is + // possible. + if (!stream->BlockHostUntilDone()) { + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } + return InternalError("Failed to complete data transfer on stream %p", + stream); + } + + infeed_manager->EnqueueBuffers(buffers); + + VLOG(2) << "Infeed data transferred"; + + return Status::OK(); +} + +StatusOr GpuTransferManager::TransferBufferToInfeedInternal( + se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return Unimplemented("Infeed shape is too large: %s needs %lld bytes", - ShapeUtil::HumanString(literal.shape()).c_str(), size); + return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); } if (size == 0) { - return Unimplemented("Infeed shape %s needs 0 bytes", - ShapeUtil::HumanString(literal.shape()).c_str()); + return InvalidArgument("Infeed shape needs 0 bytes"); } gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); @@ -71,21 +132,11 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, } gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size); - stream->ThenMemcpy(buffer->device_memory(), - LiteralUtil::InternalData(literal), size); + stream->ThenMemcpy(buffer->device_memory(), source, size); VLOG(2) << "Queued infeed data on stream " << stream; - if (!stream->BlockHostUntilDone()) { - buffer->Done(); - return InternalError("Failed to complete data transfer on stream %p", - stream); - } - - infeed_manager->EnqueueBuffer(buffer); - - VLOG(2) << "Infeed data transferred"; - return Status::OK(); + return buffer; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu_transfer_manager.h index 6dfe7ba0295aea699ca737e9dd47123b17cae3dc..9aa369c668364079504ead3491903e2590a142cc 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.h +++ 
b/tensorflow/compiler/xla/service/gpu_transfer_manager.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -37,8 +38,21 @@ class GpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; private: + // Initiates the infeed data transfers. InfeedBuffer->Done() must be + // called to clean up the memory allocated for InfeedBuffer. + StatusOr TransferBufferToInfeedInternal( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source); + + // Enqueues infeed data buffers with the infeed manager after their + // transfer completes. 
+ Status EnqueueBuffersToInfeed(perftools::gputools::StreamExecutor* executor, + std::vector buffers); + TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index cd00a41a03718502fcfa63e035639390b6fe6e07..049e8d80d80c835bca4a4d38592564ba82a3ecf9 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -47,7 +47,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); + HloInstruction::CreateConstant(Literal::CreateR0(0.5))); builder.AddInstruction(HloInstruction::CreateBinary( half->shape(), HloOpcode::kAdd, x_value, half)); return module->AddEmbeddedComputation(builder.Build()); @@ -118,7 +118,7 @@ std::unique_ptr MakeBigGraph() { auto rng = builder.AddInstruction( HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_computation = ScalarSumComputation(module.get()); builder.AddInstruction( HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); @@ -156,10 +156,9 @@ int main(int argc, char** argv) { auto module = xla::MakeBigGraph(); - printf("Graph URL: %s\n", - xla::hlo_graph_dumper::DumpGraph( - *module->entry_computation(), "Example computation", - /*show_addresses=*/false, /*show_layouts=*/false) - .c_str()); + printf("Graph URL: %s\n", xla::hlo_graph_dumper::DumpGraph( + *module->entry_computation(), + "Example computation", xla::DebugOptions()) + .c_str()); return 0; } diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc 
b/tensorflow/compiler/xla/service/heap_simulator.cc index 86f62accd3b524c3aa39c256a982bcf21edc1b25..840be603bf997f6f84e4c372c178fdf96f928f23 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -35,18 +35,26 @@ namespace { std::vector UniqueOperandSourceBuffers( const HloInstruction* instruction, const TuplePointsToAnalysis& points_to_analysis) { - FlatSet buffers; + std::vector buffers; for (const HloInstruction* operand : instruction->operands()) { - FlatSet sources = - points_to_analysis.GetPointsToSet(operand).CreateFlattenedSet(); - buffers.insert(sources.begin(), sources.end()); + points_to_analysis.GetPointsToSet(operand).ForEachElement( + [&](const ShapeIndex& /*index*/, + const std::vector& points_to) { + buffers.insert(buffers.end(), points_to.begin(), points_to.end()); + }); } - std::vector sorted(buffers.begin(), buffers.end()); - std::sort(sorted.begin(), sorted.end(), + + // Sort and then remove duplicates from buffers. 
+ std::sort(buffers.begin(), buffers.end(), [](const LogicalBuffer* a, const LogicalBuffer* b) { return a->id() < b->id(); }); - return sorted; + buffers.erase(std::unique(buffers.begin(), buffers.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() == b->id(); + }), + buffers.end()); + return buffers; } } // namespace @@ -187,7 +195,7 @@ Status HeapSimulator::RunComputation( buffer->instruction()->opcode() != HloOpcode::kCopy && CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index(), &points_to_analysis)) { ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 60a0768a86b30ad5e8810a6f289008a9ee8c8a2e..ef9db8ba236f9923420c1f8b1a7423e0c036fb0f 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -173,7 +173,7 @@ class HeapSimulatorTest : public HloTestBase { TEST_F(HeapSimulatorTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); // Constants aren't assigned. See b/32248867 HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0}); @@ -510,8 +510,7 @@ class HeapAlgorithmTestBase : public ::testing::Test { // other than the id and color. 
const LogicalBuffer* DummyLogicalBuffer() { const LogicalBuffer::Id id = buffers_.size(); - buffers_.emplace_back(MakeUnique(nullptr, ShapeIndex{}, id, - LogicalBuffer::Color(0))); + buffers_.emplace_back(MakeUnique(nullptr, ShapeIndex{}, id)); return buffers_.back().get(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 3b37f4a4b892497135c4dccc0082d244c1d8a27e..0c03d72752f97c201f2e209f99a4915ec97257ac 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -15,21 +15,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" -#include -#include +#include +#include #include #include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -38,115 +38,16 @@ using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -void HloBuffer::AddValue(const HloValue& value) { - // If the value is already contained in this buffer, just return. - if (std::find(value_ids_.begin(), value_ids_.end(), value.id()) != - value_ids_.end()) { - return; - } - - value_ids_.push_back(value.id()); - - // Add all of the locations of the HloValue to this buffer. 
- for (const HloLocation& location : value.locations()) { - if (std::find(locations_.begin(), locations_.end(), location) == - locations_.end()) { - locations_.push_back(location); - } - } -} - -bool HloBuffer::operator==(const HloBuffer& other) const { - bool equal = id() == other.id(); - if (equal) { - // DCHECK because these comparisons are expensive (linear time). - DCHECK(value_ids() == other.value_ids()); - DCHECK(locations() == other.locations()); - } - return equal; -} - -string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", Join(value_ids_, ", ")); -} - -std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { - out << buffer.ToString(); - return out; -} - -void HloBufferSet::AddBuffer(HloBuffer::Id buffer_id) { - if (std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id) == - buffer_ids_.end()) { - buffer_ids_.push_back(buffer_id); - } -} - -void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) { - auto it = std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id); - CHECK(it != buffer_ids_.end()); - buffer_ids_.erase(it); -} - -string HloBufferSet::ToString() const { - return StrCat("HloBufferSet, buffers: ", Join(buffer_ids_, ", ")); -} - -std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) { - out << buffer_set.ToString(); - return out; -} - -bool InstructionBufferSet::IsAmbiguous() const { - bool is_ambiguous = false; - ForEachElement( - [&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) { - is_ambiguous |= buffer_set.buffer_ids().size() > 1; - }); - return is_ambiguous; -} - -bool InstructionBufferSet::IsDistinct() const { - bool is_distinct = true; - tensorflow::gtl::FlatSet seen_ids; - ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& index, - const HloBufferSet& buffer_set) { - for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { - auto pair = seen_ids.insert(buffer_id); - if (!pair.second) { - is_distinct = false; - } - } - 
}); - return is_distinct; -} - -string InstructionBufferSet::ToString() const { - string out = - StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloBufferSet& value_set) { - StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); - }); - return out; -} - -std::ostream& operator<<(std::ostream& out, - const InstructionBufferSet& buffer_set) { - out << buffer_set.ToString(); - return out; -} - HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {} void HloAliasAnalysis::InitializeBufferSets() { - std::unordered_map value_to_buffer; + std::unordered_map value_to_buffer; // Initially define a buffer for every HloValue in the module. for (const HloValue* value : dataflow_analysis_->values()) { - HloBuffer& buffer = NewHloBuffer(); - buffer.AddValue(*value); - value_to_buffer[value->id()] = buffer.id(); + HloBuffer* buffer = NewHloBuffer(); + buffer->AddValue(*value); + value_to_buffer[value->id()] = buffer; } // Construct the Instruction buffer set to contain the HloBuffers for each @@ -160,9 +61,9 @@ void HloAliasAnalysis::InitializeBufferSets() { .ForEachElement( [this, &instruction, &value_to_buffer]( const ShapeIndex& index, const HloValueSet& value_set) { - for (HloValue::Id value_id : value_set.value_ids()) { - HloBuffer::Id buffer_id = value_to_buffer.at(value_id); - GetBufferSet(instruction.get(), index).AddBuffer(buffer_id); + for (const HloValue* value : value_set.values()) { + const HloBuffer* buffer = value_to_buffer.at(value->id()); + GetBufferSet(instruction.get(), index).AddBuffer(buffer); } }); } @@ -189,18 +90,18 @@ void HloAliasAnalysis::CombineBuffers( VLOG(4) << "Eliminating buffer: " << buffer_id; // Add all values held by the buffer-to-eliminate to the unified buffer. 
- for (HloValue::Id value_id : buffer.value_ids()) { - unified_buffer.AddValue(dataflow_analysis_->GetValue(value_id)); + for (const HloValue* value : buffer.values()) { + unified_buffer.AddValue(*value); } - // Iterate through all locations where the buffer-to-eliminate exists and + // Iterate through all positions where the buffer-to-eliminate exists and // replace it with the unified buffer. - for (const HloLocation& location : buffer.locations()) { - VLOG(4) << "Replacing in " << location; - GetBufferSet(location.instruction, location.index) + for (const HloPosition& position : buffer.positions()) { + VLOG(4) << "Replacing in " << position; + GetBufferSet(position.instruction, position.index) .RemoveBufferOrDie(buffer_id); - GetBufferSet(location.instruction, location.index) - .AddBuffer(unified_buffer.id()); + GetBufferSet(position.instruction, position.index) + .AddBuffer(&unified_buffer); } buffers_.erase(buffer_id); @@ -219,9 +120,9 @@ Status HloAliasAnalysis::Verify() const { TF_RETURN_IF_ERROR(instruction_buffer_set.ForEachElementWithStatus( [this, &buffers_in_sets](const ShapeIndex& index, const HloBufferSet& buffer_set) -> Status { - for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { - TF_RET_CHECK(ContainsKey(buffers_, buffer_id)); - buffers_in_sets.insert(buffer_id); + for (const HloBuffer* buffer : buffer_set.buffers()) { + TF_RET_CHECK(ContainsKey(buffers_, buffer->id())); + buffers_in_sets.insert(buffer->id()); } return Status::OK(); })); @@ -240,7 +141,7 @@ void HloAliasAnalysis::FlattenInstructionBufferSets( VLOG(4) << "Flattening buffer sets of instructions: " << Join(instructions, ", ", [this](string* out, const HloInstruction* instruction) { - StrAppend(out, instruction->FullyQualifiedName()); + StrAppend(out, instruction->name()); }); if (instructions.size() < 2) { return; @@ -253,10 +154,11 @@ void HloAliasAnalysis::FlattenInstructionBufferSets( std::vector to_unify; for (const HloInstruction* instruction : instructions) { const 
HloBufferSet& buffer_set = GetBufferSet(instruction, index); - to_unify.insert(to_unify.end(), buffer_set.buffer_ids().begin(), - buffer_set.buffer_ids().end()); + for (const HloBuffer* buffer : buffer_set.buffers()) { + to_unify.push_back(buffer->id()); + } } - // Sort and uniquify buffers to combine. + // Sort and uniquify buffer ids to combine. std::sort(to_unify.begin(), to_unify.end()); to_unify.erase(std::unique(to_unify.begin(), to_unify.end()), to_unify.end()); @@ -265,14 +167,13 @@ void HloAliasAnalysis::FlattenInstructionBufferSets( }); } -HloBuffer& HloAliasAnalysis::NewHloBuffer() { +HloBuffer* HloAliasAnalysis::NewHloBuffer() { HloBuffer::Id buffer_id = next_buffer_id_++; - auto it_added = buffers_.emplace(std::piecewise_construct, + auto emplaced = buffers_.emplace(std::piecewise_construct, std::forward_as_tuple(buffer_id), std::forward_as_tuple(buffer_id)); - CHECK(it_added.second); - - return it_added.first->second; + CHECK(emplaced.second); + return &emplaced.first->second; } string HloAliasAnalysis::ToString() const { @@ -282,34 +183,18 @@ string HloAliasAnalysis::ToString() const { module_->computations()) { for (const std::unique_ptr& instruction : computation->instructions()) { - StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n"); - auto buffer_str = [this](const HloBuffer& buffer) { - return StrCat( - "Buffer ", buffer.id(), ", values: ", - Join(buffer.value_ids(), ", ", - [this](string* out, HloValue::Id value_id) { - StrAppend( - out, - dataflow_analysis_->GetValue(value_id).ToShortString()); - })); - }; + StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { GetInstructionBufferSet(instruction.get()) - .ForEachElement([this, &out, &buffer_str]( - const ShapeIndex& index, - const HloBufferSet& buffer_set) { + .ForEachElement([this, &out](const ShapeIndex& index, + const HloBufferSet& buffer_set) { StrAppend(&out, " tuple index ", index.ToString(), ":\n"); - for (HloBuffer::Id buffer_id 
: buffer_set.buffer_ids()) { - StrAppend(&out, " ", buffer_str(GetBuffer(buffer_id)), - "\n"); - } + StrAppend(&out, " ", buffer_set.ToString(), "\n"); }); } else { const HloBufferSet top_level_buffer_set = GetBufferSet(instruction.get()); - for (HloBuffer::Id buffer_id : top_level_buffer_set.buffer_ids()) { - StrAppend(&out, " ", buffer_str(GetBuffer(buffer_id)), "\n"); - } + StrAppend(&out, " ", top_level_buffer_set.ToString(), "\n"); } } } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 0fa35827b5ecbfd3987a17e60c3b395b36b16b2e..c70ec38c990ff5ea863f737b48cdcbcd49d513c2 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -16,182 +16,23 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ -#include -#include #include -#include #include #include -#include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" namespace xla { -// A container which can 
hold one or more HloValues. An HLO buffer abstractly -// represents the allocation which HLO instructions write into and read -// from. Generally there is a one-to-one correspondence between HloBuffers and -// HloValue where each HloValue in the module is held in a unique HloBuffer. An -// exception is the while instruction which updates the loop state in-place. In -// this case, we have a single HloBuffer for each HloLocation in the loop state, -// but multiple HloValues. For example: -// -// %init = ... -// %while = While(%init, body, condition) -// -// body: -// %body_param = Param(0) -// ... -// %body_root = ... -// -// condition: -// %cond_param = Param(0) -// ... -// -// For simplicity, assume that %while is array-shaped. In this case, we have a -// single HloBuffer which holds the following HloValues: HloValue{%init}, -// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and -// HloValue{%cond_param}. -// -// HloBuffers may appear at different HloLocations in the module mirroring the -// same propery of HloValues. For example: -// -// %sub = Sub(...) -// %add = Add(...) -// %tuple = Tuple(%add, %sub) -// %gte = GetTupleElement(%tuple, 0) -// -// In this case, the HloBuffer containing %add appears at the following -// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and -// HloLocation{%gte, {}}. -// -// Different HloLocations which share the same HloBuffer indicate mandatory -// aliasing in the HLO module. These locations must share the same memory -// allocation for correctness (the backends rely on this property). This differs -// from incidental aliasing introduced by memory reuse in BufferAssignment where -// different instructions may happen to get the same allocation. -class HloBuffer { - public: - using Id = int64; - - HloBuffer(int64 id) : id_(id) {} - - // Return the unique identifier for this HloBuffer. - int64 id() const { return id_; } - - // Add a value to the set of values held by this buffer. 
Also adds the - // HloLocations of the value to the locations vector of the buffer. If the - // buffer already contains this value, then this method is a nop. - void AddValue(const HloValue& value); - - // Return the IDs of all values contained in this buffer. - const std::vector& value_ids() const { return value_ids_; } - - // Return the locations (output of which instruction and at what index) where - // the buffer is used. This is exactly the union of the locations of the - // HloValues contained by the buffer. - const std::vector& locations() const { return locations_; } - - string ToString() const; - - bool operator==(const HloBuffer& other) const; - bool operator!=(const HloBuffer& other) const { return !(*this == other); } - - private: - // Unique identifier for this HloBuffer. - const Id id_; - - // The set of values contained in the this buffer. - std::vector value_ids_; - - // The set of locations where this buffer is used. - std::vector locations_; -}; - -std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); - -// A class representing the set of possible HloBuffers at a particular -// HloLocation (shape index in the output of an instruction) in the XLA -// graph. In most cases, the buffer set will have a single HloBuffer indicating -// that the HloBuffer which appears at that particular location is known -// unambiguously at compile-time. However, tuple-shaped Select instructions can -// introduce ambiguity as the tuple elements of the operands are passed by -// reference into the output of the Select. For example: -// -// %pred = ... -// %tuple0 = Tuple(%a, %b) -// %tuple1 = Tuple(%x, %y) -// %select = Select(%pred, %tuple0, %tuple1) -// -// In this case the HloBufferSet at HloLocation{%select, {0}} contains the -// HloBuffer holding %a and the HloBuffer holding %x. -class HloBufferSet { - public: - HloBufferSet() = default; - - // Add the given buffer to this buffer set. If the buffer already exists in - // the set, then this is a NOP. 
- void AddBuffer(HloBuffer::Id buffer_id); - - // Removes the given buffer from this buffer set. CHECK fails in the buffer is - // not contained in this set. - void RemoveBufferOrDie(HloBuffer::Id buffer_id); - - // Returns the unique buffer in this set. CHECK fails if the set does not - // contain exactly one buffer. - HloBuffer::Id GetUniqueBufferId() const { - CHECK_EQ(buffer_ids().size(), 1); - return buffer_ids()[0]; - } - - // Returns the IDs of the HloBuffers contained in this buffer set. - const std::vector& buffer_ids() const { return buffer_ids_; } - - string ToString() const; - - private: - // The IDs of the HloBuffers containted in this buffer set. - std::vector buffer_ids_; -}; - -std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set); - -// A class collecting the HloBuffers in the output of an HLO instruction. For -// array-shaped instructions, an InstructionBufferSet trivially holds a single -// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple -// HloBufferSets. -class InstructionBufferSet : public ShapeTree { - public: - InstructionBufferSet(const Shape& shape) : ShapeTree(shape) {} - - // Returns true if any HloBufferSet contained in this InstructionBufferSet - // is not a singleton. - bool IsAmbiguous() const; - - // Returns true if any HloBuffer appears in more than one HloBufferSet - // contained in this InstructionBufferSet. - bool IsDistinct() const; - - string ToString() const; -}; - -std::ostream& operator<<(std::ostream& out, - const InstructionBufferSet& buffer_set); - class HloAliasAnalysis { public: static StatusOr> Run(HloModule* module); @@ -204,7 +45,7 @@ class HloAliasAnalysis { InstructionBufferSet& GetInstructionBufferSet( const HloInstruction* instruction); - // Return the HloBufferSet for the given location. + // Return the HloBufferSet for the given position. 
const HloBufferSet& GetBufferSet(const HloInstruction* instruction, const ShapeIndex& index = {}) const; HloBufferSet& GetBufferSet(const HloInstruction* instruction, @@ -218,15 +59,15 @@ class HloAliasAnalysis { return buffers_.at(buffer_id); } - // Returns the unique buffer at the given location. CHECK fails if the buffer - // set at that location does not contain exactly one buffer. + // Returns the unique buffer at the given position. CHECK fails if the buffer + // set at that position does not contain exactly one buffer. const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const { - return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId()); + return GetBufferSet(instruction, index).GetUniqueBuffer(); } HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, const ShapeIndex& index = {}) { - return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId()); + return GetBuffer(GetBufferSet(instruction, index).GetUniqueBuffer().id()); } // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This @@ -242,8 +83,8 @@ class HloAliasAnalysis { protected: HloAliasAnalysis(HloModule* module); - // Creates a new HloBuffer and returns a reference to it. - HloBuffer& NewHloBuffer(); + // Returns a new HloBuffer. + HloBuffer* NewHloBuffer(); // Construct the initial set of buffer sets where an HloBuffer is created for // each HloValue in the module. @@ -282,7 +123,9 @@ class HloAliasAnalysis { // The underlying dataflow analysis used by this alias analysis. std::unique_ptr dataflow_analysis_; - // The map of all HloBuffers in the module. + // The map of all HloBuffers in the module. We pass around pointers to the + // mapped HloBuffers, so the underlying container must keep them valid despite + // mutations touching other map entries. std::unordered_map buffers_; // A map from instruction to its InstructionBufferSet. 
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 24c467d411b93be32bd884a8bb92ef288d9c2f10..3c5b2e03b762be2247a5c58b13915ae883c93622 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -37,22 +37,22 @@ using ::testing::UnorderedElementsAre; class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : module_(TestName()) {} + HloAliasAnalysisTest() : module_(CreateNewModule()) {} // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. const HloAliasAnalysis& RunAnalysis() { - analysis_ = HloAliasAnalysis::Run(&module_).ConsumeValueOrDie(); + analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); return *analysis_; } - // Return a vector of the buffers in the buffer set at the current location. + // Return a vector of the buffers in the buffer set at the current position. std::vector GetBuffersAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const { std::vector buffers; - for (HloBuffer::Id buffer_id : - analysis_->GetBufferSet(instruction, index).buffer_ids()) { - buffers.push_back(analysis_->GetBuffer(buffer_id)); + for (const HloBuffer* buffer : + analysis_->GetBufferSet(instruction, index).buffers()) { + buffers.push_back(*buffer); } return buffers; } @@ -60,24 +60,41 @@ class HloAliasAnalysisTest : public HloTestBase { // Return a vector containing all of the HloValues in the given buffer. std::vector GetValuesInBuffer(const HloBuffer& buffer) { std::vector values; - for (HloValue::Id value_id : buffer.value_ids()) { - values.push_back(analysis_->dataflow_analysis().GetValue(value_id)); + for (const HloValue* value : buffer.values()) { + values.push_back(*value); } return values; } - // Return the HloValue defined at the given location. 
+ // Return the HloValue defined at the given position. const HloValue& GetValueDefinedAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const { return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index); } - const HloValue& GetUniqueValueInBuffer(const HloBuffer& buffer) const { - CHECK_EQ(buffer.value_ids().size(), 1); - return analysis_->dataflow_analysis().GetValue(buffer.value_ids()[0]); + // Returns true if any values held in the same buffer interfere. Generally, in + // the compiler pipeline copy-insertion will guarantee that this interference + // never occurs, but HLO graphs with interference can be explicitly + // constructed. + bool AnyValuesInSameBufferInterfere() { + DependencyHloOrdering ordering(module_.get()); + for (const HloBuffer* buffer : analysis_->buffers()) { + for (const HloValue* value_a : buffer->values()) { + for (const HloValue* value_b : buffer->values()) { + if (*value_a != *value_b && + analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b, + ordering)) { + VLOG(1) << *value_a << " interferes with " << *value_b + << " in buffer: " << *buffer; + return true; + } + } + } + } + return false; } - HloModule module_; + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -87,12 +104,12 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // Test the analysis on a single binary operation (Add). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -101,12 +118,14 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // All of the buffer sets should trivially contain a single buffer containing // a single value. for (const HloInstruction* instruction : {constant1, constant2, add}) { - EXPECT_EQ(GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(instruction)), + EXPECT_EQ(analysis.GetUniqueBufferAt(instruction).GetUniqueValue(), GetValueDefinedAt(instruction)); } EXPECT_FALSE(analysis.GetInstructionBufferSet(add).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(add).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleAndGtes) { @@ -124,22 +143,19 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); EXPECT_EQ(analysis.buffers().size(), 4); // Verify the expected aliasing of the tuple elements. 
- EXPECT_EQ( - GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(tuple, /*index=*/{})), - GetValueDefinedAt(tuple, /*index=*/{})); - EXPECT_EQ( - GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(tuple, /*index=*/{0})), - GetValueDefinedAt(param0)); - EXPECT_EQ( - GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(tuple, /*index=*/{1})), - GetValueDefinedAt(param1)); + EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}).GetUniqueValue(), + GetValueDefinedAt(tuple, /*index=*/{})); + EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{0}).GetUniqueValue(), + GetValueDefinedAt(param0)); + EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{1}).GetUniqueValue(), + GetValueDefinedAt(param1)); // The tuple operand, tuple element, and result of the GTE instruction should // all be the same buffer. @@ -148,14 +164,16 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { EXPECT_EQ(analysis.GetUniqueBufferAt(param0), analysis.GetUniqueBufferAt(gte0)); - // Verify the locations of an aliased buffer. + // Verify the positions of an aliased buffer. EXPECT_THAT( - analysis.GetUniqueBufferAt(param0).locations(), - UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, - HloLocation{gte0, {}})); + analysis.GetUniqueBufferAt(param0).positions(), + UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, + HloPosition{gte0, {}})); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, NondistinctTuple) { @@ -168,17 +186,19 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { // param0 is included twice in the tuple. 
auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({param0, param1, param0})); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); EXPECT_THAT( - analysis.GetUniqueBufferAt(param0).locations(), - UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, - HloLocation{tuple, {2}})); + analysis.GetUniqueBufferAt(param0).positions(), + UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, + HloPosition{tuple, {2}})); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SingleCall) { @@ -192,31 +212,33 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); // Verify aliasing of the kCall operands and the subcomputation parameters. 
- EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(), - UnorderedElementsAre(HloLocation{constant1, {}}, - HloLocation{subparam0, {}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(), - UnorderedElementsAre(HloLocation{constant2, {}}, - HloLocation{subparam1, {}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(), + UnorderedElementsAre(HloPosition{constant1, {}}, + HloPosition{subparam0, {}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(), + UnorderedElementsAre(HloPosition{constant2, {}}, + HloPosition{subparam1, {}})); // The subcomputation root and the kCall itself should alias. EXPECT_THAT( - analysis.GetUniqueBufferAt(add).locations(), - UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}})); + analysis.GetUniqueBufferAt(add).positions(), + UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { @@ -229,35 +251,35 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); - EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(), - UnorderedElementsAre(HloLocation{constant1, {}}, - HloLocation{subparam0, {}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(), - UnorderedElementsAre(HloLocation{constant2, {}}, - HloLocation{subparam1, {}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(), + UnorderedElementsAre(HloPosition{constant1, {}}, + HloPosition{subparam0, {}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(), + UnorderedElementsAre(HloPosition{constant2, {}}, + HloPosition{subparam1, {}})); // The 'add' (root of the subcomputation) aliases the two call instruction, // and the first parameter of the subcomputation because 'call1' it is passed // as an argument to the subcomputation in 'call2'. EXPECT_THAT( - analysis.GetUniqueBufferAt(add).locations(), - UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call1, {}}, - HloLocation{subparam0, {}}, HloLocation{call2, {}})); + analysis.GetUniqueBufferAt(add).positions(), + UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}}, + HloPosition{subparam0, {}}, HloPosition{call2, {}})); EXPECT_THAT(GetBuffersAt(subparam0), UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), @@ -269,6 +291,8 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { EXPECT_FALSE(analysis.GetInstructionBufferSet(subparam1).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam1).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SingleWhile) { @@ -303,48 +327,48 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); auto body_tuple = body_builder.AddInstruction( 
HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); // Condition computation trivially returns a constant "false". auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); - // Verify the locations of the aliased while buffers. - EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).locations(), + // Verify the positions of the aliased while buffers. 
+ EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).positions(), UnorderedElementsAre( - HloLocation{tuple, {}}, HloLocation{xla_while, {}}, - HloLocation{body_param, {}}, HloLocation{body_tuple, {}}, - HloLocation{cond_param, {}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).locations(), + HloPosition{tuple, {}}, HloPosition{xla_while, {}}, + HloPosition{body_param, {}}, HloPosition{body_tuple, {}}, + HloPosition{cond_param, {}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).positions(), UnorderedElementsAre( - HloLocation{constant1, {}}, HloLocation{tuple, {0}}, - HloLocation{xla_while, {0}}, HloLocation{body_param, {0}}, - HloLocation{body_element_0, {}}, HloLocation{body_tuple, {0}}, - HloLocation{cond_param, {0}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).locations(), + HloPosition{constant1, {}}, HloPosition{tuple, {0}}, + HloPosition{xla_while, {0}}, HloPosition{body_param, {0}}, + HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}}, + HloPosition{cond_param, {0}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).positions(), UnorderedElementsAre( - HloLocation{constant2, {}}, HloLocation{tuple, {1}}, - HloLocation{xla_while, {1}}, HloLocation{body_param, {1}}, - HloLocation{body_element_1, {}}, HloLocation{add, {}}, - HloLocation{body_tuple, {1}}, HloLocation{cond_param, {1}})); + HloPosition{constant2, {}}, HloPosition{tuple, {1}}, + HloPosition{xla_while, {1}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}})); EXPECT_THAT( GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})), @@ -356,6 +380,8 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { GetValueDefinedAt(body_param, {1}), GetValueDefinedAt(cond_param, {1}), GetValueDefinedAt(add))); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, 
SequentialWhiles) { @@ -392,21 +418,21 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -415,7 +441,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -449,13 +475,21 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - 
HloInstruction::CreateParameter(0, tuple_shape, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); - HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + auto build_cond_computation = [&tuple_shape]() { + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return cond_builder.Build(); + }; + // Build separate condition computations so the call graph is flat. The + // callgraph is always flattened in the compiler pipeline, and the flattened + // callgraph enables representative interference analysis. + HloComputation* condition1 = + module_->AddEmbeddedComputation(build_cond_computation()); + HloComputation* condition2 = + module_->AddEmbeddedComputation(build_cond_computation()); // Element 0 passes transparently through the body. auto inner_builder = HloComputation::Builder("inner_body"); @@ -470,7 +504,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { inner_builder.AddInstruction( HloInstruction::CreateTuple({inner_element_0, add})); HloComputation* inner_body = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); // Element 1 passes transparently through the body. 
auto outer_builder = HloComputation::Builder("outer_body"); @@ -485,20 +519,20 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto outer_tuple = outer_builder.AddInstruction( HloInstruction::CreateTuple({negate, outer_element_1})); auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, condition, inner_body, outer_tuple)); + tuple_shape, condition1, inner_body, outer_tuple)); HloComputation* outer_body = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( - HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); - module_.AddEntryComputation(builder.Build()); + HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple)); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -515,6 +549,8 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { analysis.GetUniqueBufferAt(nested_while, /*index=*/{1})); EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), analysis.GetUniqueBufferAt(inner_element_1)); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { @@ -548,32 +584,32 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2)); body_builder.AddInstruction(HloInstruction::CreateTuple( {body_element_1, body_element_2, body_element_0})); - HloComputation* body = 
module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2, constant3})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); - // The swizzling while makes most locations in the module alias leaving only 3 + // The swizzling while makes most positions in the module alias leaving only 3 // HloBuffers. 
EXPECT_THAT( analysis.buffers(), @@ -593,6 +629,10 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { analysis.GetUniqueBufferAt(constant2)); EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), analysis.GetUniqueBufferAt(constant3)); + + // All elements of the loop state tuple are forced into the same buffer + // resulting in liveness interference. + EXPECT_TRUE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleSelect) { @@ -600,15 +640,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { // instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -627,24 +667,24 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, select12, select34)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); // Verify the buffer sets of each select.
- EXPECT_THAT(analysis.GetBufferSet(select11, /*index=*/{0}).buffer_ids(), - UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1).id())); - EXPECT_THAT(analysis.GetBufferSet(select12, /*index=*/{0}).buffer_ids(), - UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1).id(), - analysis.GetUniqueBufferAt(constant2).id())); - EXPECT_THAT(analysis.GetBufferSet(select34, /*index=*/{0}).buffer_ids(), - UnorderedElementsAre(analysis.GetUniqueBufferAt(constant3).id(), - analysis.GetUniqueBufferAt(constant4).id())); - EXPECT_THAT(analysis.GetBufferSet(select1234, /*index=*/{0}).buffer_ids(), - UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1).id(), - analysis.GetUniqueBufferAt(constant2).id(), - analysis.GetUniqueBufferAt(constant3).id(), - analysis.GetUniqueBufferAt(constant4).id())); + EXPECT_THAT(GetBuffersAt(select11, /*index=*/{0}), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1))); + EXPECT_THAT(GetBuffersAt(select12, /*index=*/{0}), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), + analysis.GetUniqueBufferAt(constant2))); + EXPECT_THAT(GetBuffersAt(select34, /*index=*/{0}), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant3), + analysis.GetUniqueBufferAt(constant4))); + EXPECT_THAT(GetBuffersAt(select1234, /*index=*/{0}), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), + analysis.GetUniqueBufferAt(constant2), + analysis.GetUniqueBufferAt(constant3), + analysis.GetUniqueBufferAt(constant4))); EXPECT_FALSE(analysis.GetInstructionBufferSet(select11).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsAmbiguous()); @@ -655,6 +695,8 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, 
TupleSelectToWhile) { @@ -688,22 +730,22 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kNegate, body_element)); body_builder.AddInstruction(HloInstruction::CreateTuple({negate})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -713,7 +755,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, select)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -736,17 +778,21 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { EXPECT_TRUE(analysis.GetInstructionBufferSet(select).IsDistinct()); 
EXPECT_TRUE(analysis.GetInstructionBufferSet(xla_while).IsDistinct()); + + // The two operands of the select get flattened into the same buffer resulting + // in liveness interference. + EXPECT_TRUE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, Bitcast) { // Bitcasting a value should not produce a new buffer. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1ad5daf79ce9e71652b8b6cba6e36ba57a838bc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_buffer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +void HloBuffer::AddValue(const HloValue& value) { + // If the value is already contained in this buffer, just return. + if (!values_.AddValue(&value)) { + return; + } + + // Add all of the positions of the HloValue to this buffer. + for (const HloPosition& position : value.positions()) { + if (std::find(positions_.begin(), positions_.end(), position) == + positions_.end()) { + positions_.push_back(position); + } + } +} + +bool HloBuffer::operator==(const HloBuffer& other) const { + bool equal = id() == other.id(); + if (equal) { + // DCHECK because these comparisons are expensive (linear time). 
+ DCHECK(values_ == other.values_); + DCHECK(positions() == other.positions()); + } + return equal; +} + +string HloBuffer::ToString() const { + return StrCat("HloBuffer ", id_, ", values: ", values_.ToString()); +} + +std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { + out << buffer.ToString(); + return out; +} + +void HloBufferSet::AddBuffer(const HloBuffer* buffer) { + auto it = std::lower_bound(buffers_.begin(), buffers_.end(), buffer, + HloBuffer::IdLessThan); + if (it == buffers_.end() || (*it)->id() != buffer->id()) { + buffers_.insert(it, buffer); + } +} + +void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) { + auto it = std::lower_bound(buffers_.begin(), buffers_.end(), buffer_id, + [](const HloBuffer* buffer, HloBuffer::Id id) { + return buffer->id() < id; + }); + CHECK(it != buffers_.end() && (*it)->id() == buffer_id) + << "HloBuffer " << buffer_id << " doesn't exist in set: " << ToString(); + buffers_.erase(it); +} + +string HloBufferSet::ToString() const { + return StrCat( + "HloBufferSet, buffers: ", + Join(buffers_, ", ", [](string* result, const HloBuffer* buffer) { + result->append(buffer->ToString()); + })); +} + +std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) { + out << buffer_set.ToString(); + return out; +} + +bool InstructionBufferSet::IsAmbiguous() const { + bool is_ambiguous = false; + ForEachElement( + [&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) { + is_ambiguous |= buffer_set.buffers().size() > 1; + }); + return is_ambiguous; +} + +bool InstructionBufferSet::IsDistinct() const { + bool is_distinct = true; + tensorflow::gtl::FlatSet seen_ids; + ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& /*index*/, + const HloBufferSet& buffer_set) { + for (const HloBuffer* buffer : buffer_set.buffers()) { + auto pair = seen_ids.insert(buffer->id()); + if (!pair.second) { + is_distinct = false; + } + } + }); + return is_distinct; +} + +string 
InstructionBufferSet::ToString() const { + string out = + StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n"); + ForEachElement([this, &out](const ShapeIndex& index, + const HloBufferSet& value_set) { + StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); + }); + return out; +} + +std::ostream& operator<<(std::ostream& out, + const InstructionBufferSet& buffer_set) { + out << buffer_set.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h new file mode 100644 index 0000000000000000000000000000000000000000..f42d2f7720e44978fdbac8783e1b4b70e3bf3a01 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// A container which can hold one or more HloValues. 
An HLO buffer abstractly +// represents the allocation which HLO instructions write into and read +// from. Generally there is a one-to-one correspondence between HloBuffers and +// HloValue where each HloValue in the module is held in a unique HloBuffer. An +// exception is the while instruction which updates the loop state in-place. In +// this case, we have a single HloBuffer for each HloPosition in the loop state, +// but multiple HloValues. For example: +// +// %init = ... +// %while = While(%init, body, condition) +// +// body: +// %body_param = Param(0) +// ... +// %body_root = ... +// +// condition: +// %cond_param = Param(0) +// ... +// +// For simplicity, assume that %while is array-shaped. In this case, we have a +// single HloBuffer which holds the following HloValues: HloValue{%init}, +// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and +// HloValue{%cond_param}. +// +// HloBuffers may appear at different HloPositions in the module mirroring the +// same property of HloValues. For example: +// +// %sub = Sub(...) +// %add = Add(...) +// %tuple = Tuple(%add, %sub) +// %gte = GetTupleElement(%tuple, 0) +// +// In this case, the HloBuffer containing %add appears at the following +// positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and +// HloPosition{%gte, {}}. +// +// Different HloPositions which share the same HloBuffer indicate mandatory +// aliasing in the HLO module. These positions must share the same memory +// allocation for correctness (the backends rely on this property). This differs +// from incidental aliasing introduced by memory reuse in BufferAssignment where +// different instructions may happen to get the same allocation. +class HloBuffer { + public: + using Id = int64; + + // Predicate comparing HloBuffers by increasing id, useful for std::sort.
+ static bool IdLessThan(const HloBuffer* a, const HloBuffer* b) { + return a->id() < b->id(); + } + + // Predicate comparing HloBuffers by equal id, useful for std::unique. + static bool IdEqual(const HloBuffer* a, const HloBuffer* b) { + return a->id() == b->id(); + } + + HloBuffer(Id id) : id_(id) {} + + // Return the unique identifier for this HloBuffer. + Id id() const { return id_; } + + // Add a value to the set of values held by this buffer. Also adds the + // HloPositions of the value to the positions vector of the buffer. If the + // buffer already contains this value, then this method is a nop. + void AddValue(const HloValue& value); + + // Return all values contained in this buffer. + const std::vector& values() const { + return values_.values(); + } + + // Return the unique HLO value in the buffer. CHECK fails if the buffer does + // not contain exactly one value. + const HloValue& GetUniqueValue() const { return values_.GetUniqueValue(); } + + // Return the positions (output of which instruction and at what index) where + // the buffer is used. This is exactly the union of the positions of the + // HloValues contained by the buffer. + const std::vector& positions() const { return positions_; } + + string ToString() const; + + bool operator==(const HloBuffer& other) const; + bool operator!=(const HloBuffer& other) const { return !(*this == other); } + + private: + // Unique identifier for this HloBuffer. + const Id id_; + + // The set of values contained in this buffer. + HloValueSet values_; + + // The set of positions where this buffer is used. + std::vector positions_; +}; + +std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); + +// A class representing the set of possible HloBuffers at a particular +// HloPosition (shape index in the output of an instruction) in the XLA +// graph. 
In most cases, the buffer set will have a single HloBuffer indicating +// that the HloBuffer which appears at that particular position is known +// unambiguously at compile-time. However, tuple-shaped Select instructions can +// introduce ambiguity as the tuple elements of the operands are passed by +// reference into the output of the Select. For example: +// +// %pred = ... +// %tuple0 = Tuple(%a, %b) +// %tuple1 = Tuple(%x, %y) +// %select = Select(%pred, %tuple0, %tuple1) +// +// In this case the HloBufferSet at HloPosition{%select, {0}} contains the +// HloBuffer holding %a and the HloBuffer holding %x. +class HloBufferSet { + public: + HloBufferSet() = default; + + // Add the given buffer to this buffer set. If the buffer already exists in + // the set, then this is a NOP. + void AddBuffer(const HloBuffer* buffer); + + // Removes the given buffer from this buffer set. CHECK fails if the buffer is + // not contained in this set. + void RemoveBufferOrDie(HloBuffer::Id buffer_id); + + // Returns the unique buffer in this set. CHECK fails if the set does not + // contain exactly one buffer. + const HloBuffer& GetUniqueBuffer() const { + CHECK_EQ(buffers_.size(), 1); + return *buffers_[0]; + } + + // Returns the vector of HloBuffers in the set, sorted by HloBuffer::Id. + const std::vector& buffers() const { return buffers_; } + + string ToString() const; + + private: + // HloBuffers sorted by HloBuffer::Id. + std::vector buffers_; +}; + +std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set); + +// A class collecting the HloBuffers in the output of an HLO instruction. For +// array-shaped instructions, an InstructionBufferSet trivially holds a single +// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple +// HloBufferSets.
+class InstructionBufferSet : public ShapeTree { + public: + InstructionBufferSet(const Shape& shape) : ShapeTree(shape) {} + + // Returns true if any HloBufferSet contained in this InstructionBufferSet + // is not a singleton. + bool IsAmbiguous() const; + + // Returns true if any HloBuffer appears in more than one HloBufferSet + // contained in this InstructionBufferSet. + bool IsDistinct() const; + + string ToString() const; +}; + +std::ostream& operator<<(std::ostream& out, + const InstructionBufferSet& buffer_set); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff76cc7bf67e29d489f9b32e4fce94ce28b59992..119cf7dde5a79add498c94ea6f0cf385f4363764 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -66,22 +67,25 @@ HloComputation::HloComputation( HloInstruction* root_instruction, bool is_fusion_computation) : name_(name), root_instruction_(root_instruction), - is_fusion_computation_(is_fusion_computation), - instruction_name_uniquer_(/*separator=*/".") { + is_fusion_computation_(is_fusion_computation) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { if (instruction->opcode() == HloOpcode::kParameter) { int64 param_no = instruction->parameter_number(); - CHECK_GE(param_no, 0); - CHECK_LT(param_no, param_instructions_.size()); - CHECK_EQ(nullptr, 
param_instructions_[param_no]); + CHECK(param_no >= 0 && param_no < parameter_count) + << "\nERROR: invalid parameter number. Expected [0, " + << parameter_count << "), got " << param_no; + CHECK(param_instructions_[param_no] == nullptr) + << "\nERROR: parameter number " << param_no + << " already allocated in this computation"; param_instructions_[param_no] = instruction.get(); } root_found |= instruction.get() == root_instruction_; AddInstructionInternal(std::move(instruction)); } - CHECK(root_found); + CHECK(root_found) + << "\nERROR: root instruction is not present in computation."; } HloInstruction* HloComputation::AddInstruction( @@ -94,8 +98,9 @@ HloInstruction* HloComputation::AddInstruction( HloInstruction* HloComputation::AddInstructionInternal( std::unique_ptr instruction) { - // Generate a unique name for the instruction. - instruction->UniquifyName(&instruction_name_uniquer_); + if (parent() != nullptr) { + instruction->UniquifyName(&parent()->instruction_name_uniquer()); + } Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = @@ -206,7 +211,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( Status HloComputation::RemoveInstruction(HloInstruction* instruction) { VLOG(2) << "Removing instruction " << instruction->name() << " from computation " << name(); - TF_RET_CHECK(IsRemovable(instruction)); + TF_RET_CHECK(IsRemovable(instruction)) + << "cannot remove instruction: " << instruction->ToString(); TF_RET_CHECK(root_instruction() != instruction) << "cannot remove root instruction " << instruction->name(); TF_RET_CHECK(instruction->user_count() == 0) @@ -537,67 +543,46 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -HloComputation::ReachabilityMap::ReachabilityMap( - const std::list& all_instructions) { - const int n = all_instructions.size(); - int next_id = 0; - for (const auto* hlo : 
all_instructions) { - ids_[hlo] = next_id; - next_id++; - } - DCHECK_EQ(n, ids_.size()); // instructions should be unique - matrix_.Reset(n * n); -} - -void HloComputation::ReachabilityMap::SetReachable(const HloInstruction* a, - const HloInstruction* b) { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - matrix_.set(id_a * ids_.size() + id_b); -} +std::unique_ptr HloComputation::ComputeReachability() + const { + const std::list all = MakeInstructionPostOrder(); + auto result = MakeUnique(all); -bool HloComputation::ReachabilityMap::IsReachable( - const HloInstruction* a, const HloInstruction* b) const { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - return matrix_.get(id_a * ids_.size() + id_b); + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + result->SetReachabilityToUnion(inputs, hlo); + } + return result; } -bool HloComputation::ReachabilityMap::IsConnected( - const HloInstruction* a, const HloInstruction* b) const { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - return matrix_.get(id_a * ids_.size() + id_b) || - matrix_.get(id_b * ids_.size() + id_a); -} +void HloComputation::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction, HloReachabilityMap* reachability_map) { + std::queue worklist; + worklist.push(instruction); -void HloComputation::ReachabilityMap::SetReachableAndTransitiveClosure( - const HloInstruction* a, const HloInstruction* b) { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - const int n = ids_.size(); - matrix_.set(id_a * n + id_b); + std::vector inputs; - // Copy transitive set for b into entries for a - for (int i = 0; i < n; i++) { - if (matrix_.get(id_b * n + i)) { - matrix_.set(id_a * n + i); - } - } -} + 
while (!worklist.empty()) { + const HloInstruction* item = worklist.front(); + worklist.pop(); -std::unique_ptr -HloComputation::ComputeTransitiveOperands() const { - const auto all = MakeInstructionPostOrder(); - auto result = MakeUnique(all); + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), item->control_predecessors().begin(), + item->control_predecessors().end()); - // Fill in the dependency bit matrix - for (const auto* hlo : all) { - for (const HloInstruction* operand : hlo->operands()) { - result->SetReachableAndTransitiveClosure(hlo, operand); + if (reachability_map->SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } } } - return result; } std::vector HloComputation::CollectUnreachableRoots() const { @@ -609,6 +594,12 @@ std::vector HloComputation::CollectUnreachableRoots() const { unreachable_roots.push_back(instruction.get()); } } + VLOG(3) << "Unreachable roots:" + << tensorflow::str_util::Join( + unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + tensorflow::strings::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -617,6 +608,7 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const { // visited root, which would invalidate iterators if the unreachable roots // weren't computed ahead of time. for (HloInstruction* root : CollectUnreachableRoots()) { + VLOG(3) << "Traversing unreachable root: " << root->ToString(); // Call FinishVisit only at the end. 
TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false)); } @@ -643,9 +635,15 @@ Status HloComputation::AcceptWithOperandOrder( Status HloComputation::AcceptOrdered( DfsHloVisitor* visitor, const std::vector& order) const { + VLOG(3) << "Accepting visitor with order."; + for (HloInstruction* root : CollectUnreachableRoots()) { + TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) + << root->ToString(); + } TF_RET_CHECK(order.size() == instruction_count()); std::unordered_set visited; for (const HloInstruction* instruction : order) { + VLOG(3) << "Visiting ordered: " << instruction->ToString(); TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) << "Instruction " << instruction->name() << " is not in computation " << name(); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 39074b24e41f073b6b5b60880cbd1f6e2e9b399d..cf6df3c94f885816d20530161822f7cc948a30be 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,11 +29,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -153,9 +153,18 @@ class HloComputation { // this order, definitions of values always appear before their uses. 
std::list MakeInstructionPostOrder() const; - // Computes and returns the mapping from HLO to its transitive operands. - class ReachabilityMap; - std::unique_ptr ComputeTransitiveOperands() const; + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + std::unique_ptr ComputeReachability() const; + + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. + void UpdateReachabilityThroughInstruction( + const HloInstruction* instruction, HloReachabilityMap* reachability_map); int64 instruction_count() const { return instructions_.size(); } @@ -308,34 +317,6 @@ class HloComputation { TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); }; -class HloComputation::ReachabilityMap { - public: - // Sets up an empty reachable matrix for the full set of - // instructions specified in "all_instructions" - explicit ReachabilityMap(const std::list& all_instructions); - // Sets entry so that IsReachable(a, b) will return true - void SetReachable(const HloInstruction* a, const HloInstruction* b); - - // Sets IsReachable(a_inst, b_inst) as well as IsReachable(a_inst, trans) - // for all "trans" s.t.
"IsReachable(b_inst, trans)" is true - void SetReachableAndTransitiveClosure(const HloInstruction* a_inst, - const HloInstruction* b_inst); - - // Returns true if "b" is reachable from "a" - bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; - - // Returns true if "b" is reachable from "a" or "a" is reachable from "b" - bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; - - private: - friend class HloComputation; - - // dense id assignment from HloInstruction* to number - tensorflow::gtl::FlatMap ids_; - // matrix_(a,b) is true iff b is reachable from a - tensorflow::core::Bitmap matrix_; -}; - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 5d49c83e2d070cb9e5409a62983940225b903b2b..4a4a8556692b3da6f92f8333397a9537ade2f8ef 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -110,7 +110,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto computation = builder.Build(); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -121,7 +121,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -136,7 +136,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -155,13 +155,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto computation = builder.Build(); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -173,11 +173,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -197,11 +197,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -248,7 +248,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto computation = builder.Build(); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -260,9 +260,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -280,7 +280,7 @@ TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -303,7 +303,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -326,9 +326,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -352,6 +352,105 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } +TEST_F(HloComputationTest, Reachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); + + auto computation = builder.Build(/*root_instruction=*/mul); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = computation->ComputeReachability(); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + 
EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, 
exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 93f448e701853e271646c9f8fb0d42f49489b756..1a2eed5f6026dc6a27e4879e63ecc378d2064d47 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -42,6 +42,9 @@ StatusOr HloConstantFolding::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } for (auto instruction : computation->MakeInstructionPostOrder()) { // Skip dead code. if (instruction->user_count() == 0 && @@ -58,6 +61,13 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Broadcasts dramatically increase the size of constants which is often + // detrimental to performance and memory capacity so do not fold + // broadcasts. + if (instruction->opcode() == HloOpcode::kBroadcast) { + continue; + } + std::unique_ptr result = evaluator->TryEvaluate(instruction); // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations.
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 31b81052cb2b00e602b94b9d84525a623caa741e..3ae499d5e0c37532ae0a83a4a247cab85fd2c84e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); @@ -51,19 +51,18 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), + EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42); } TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -73,19 +72,18 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, 
const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), + EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42.0f); } TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0f, 19.0f}))); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({42.0f, 19.0f}))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); @@ -95,16 +93,12 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {0}), - 42); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {1}), - 19); + EXPECT_EQ(computation->root_instruction()->literal().Get({0}), 42); + EXPECT_EQ(computation->root_instruction()->literal().Get({1}), 19); } TEST_F(HloConstantFoldingTest, Concatenate) { @@ -126,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { for (auto csize : test_config.concat_sizes) { dimensions[test_config.concat_dimension] = csize; concat_size += csize; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + auto literal = Literal::CreateFromDimensions(F32, dimensions); HloInstruction* insn = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); operands.push_back(insn); @@ -139,7 +133,7 @@ TEST_F(HloConstantFoldingTest, 
Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -154,9 +148,9 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_start[] = {4, 2, 3, 1, 5}; const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; - TF_ASSIGN_OR_ASSERT_OK(auto literal, - LiteralTestUtil::CreateRandomLiteral( - ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); @@ -166,7 +160,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -177,10 +171,10 @@ TEST_F(HloConstantFoldingTest, Slice) { TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; - TF_ASSIGN_OR_ASSERT_OK(auto literal, - LiteralTestUtil::CreateRandomLiteral( - ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = LiteralUtil::CloneToUnique(*literal); + TF_ASSERT_OK_AND_ASSIGN(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = 
builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -191,7 +185,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -200,12 +194,10 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; - LiteralUtil::EachCell( - root->literal(), + root->literal().EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == LiteralUtil::Get(*literal_clone, - rindexes)); + matched = matched && (value == literal_clone->Get(rindexes)); }); EXPECT_TRUE(matched); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 38cc74b0f1e640d4e72188416258d9b262053152..efc3b1c49c6ddca15a3615dc12551c4557ec841c 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -25,34 +25,56 @@ limitations under the License. 
namespace xla { +constexpr char HloCostAnalysis::kFlopsKey[]; +constexpr char HloCostAnalysis::kTranscendentalsKey[]; +constexpr char HloCostAnalysis::kBytesAccessedKey[]; +constexpr char HloCostAnalysis::kSecondsKey[]; + +HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) + : HloCostAnalysis(shape_size, {}) {} + +HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size, + const Properties& per_second_rates) + : shape_size_(shape_size), per_second_rates_(per_second_rates) {} + Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { // Set current instruction cost values to reasonable default values. Each - // handler can overwrite these values. In Postprocess, these value are + // handler can overwrite these values. In Postprocess, these values are // accumulated and written to the per-instruction maps. - current_flop_count_ = 0; - current_transcendental_count_ = 0; + current_properties_.clear(); + current_should_compute_bottleneck_time_ = true; - // The default element count for an instruction is the sum of elements in the - // operands and output. The default ShapeUtil::ByteSizeOf does not handle - // opaque types. - current_bytes_accessed_ = shape_size_(hlo->shape()); + // The default number of bytes accessed for an instruction is the sum of the + // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not + // handle opaque types. + float bytes_accessed = shape_size_(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - current_bytes_accessed_ += shape_size_(operand->shape()); + bytes_accessed += shape_size_(operand->shape()); } + current_properties_[kBytesAccessedKey] = bytes_accessed; return Status::OK(); } Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { - // Accumulate cost values and write into per-instruction maps. 
- flop_count_ += current_flop_count_; - hlo_to_flop_count_[hlo] = current_flop_count_; - - transcendental_count_ += current_transcendental_count_; - hlo_to_transcendental_count_[hlo] = current_transcendental_count_; + if (current_should_compute_bottleneck_time_) { + // Compute the time as the time of the bottleneck, i.e. the slowest property + // given the per-second rate of each property. + float max_seconds = 0.0f; + for (const auto& property : current_properties_) { + if (property.first != kSecondsKey) { + max_seconds = std::max( + max_seconds, + property.second / GetProperty(property.first, per_second_rates_)); + } + } + current_properties_[kSecondsKey] = max_seconds; + } - bytes_accessed_ += current_bytes_accessed_; - hlo_to_bytes_accessed_[hlo] = current_bytes_accessed_; + TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); + for (const auto& property : current_properties_) { + properties_sum_[property.first] += property.second; + } return Status::OK(); } @@ -65,25 +87,39 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { auto opcode = hlo_instruction->opcode(); // We treat the two opcodes (kExp, kPower) as transcendental operations. if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) { - current_transcendental_count_ = computation_count; + current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from // FLOPs. - current_flop_count_ = computation_count; + current_properties_[kFlopsKey] = computation_count; } return Status::OK(); } +/*static*/ float HloCostAnalysis::GetProperty(const string& key, + const Properties& properties) { + auto key_value = properties.find(key); + return key_value == properties.end() ? 
0.0f : key_value->second; +} + +/*static*/ float HloCostAnalysis::GetPropertyForHlo( + const HloInstruction& hlo, const string& key, + const HloToProperties& hlo_to_properties) { + auto it = hlo_to_properties.find(&hlo); + if (it == hlo_to_properties.end()) { + return 0.0f; + } else { + return GetProperty(key, it->second); + } +} + Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } @@ -100,14 +136,18 @@ Status HloCostAnalysis::HandleClamp(HloInstruction* clamp, return HandleElementwiseOp(clamp); } +Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) { + return HandleElementwiseOp(hlo); +} + Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(HloInstruction* constant, const Literal& literal) { - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -115,7 +155,7 @@ Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) { // GetTupleElement forwards a pointer and does not touch each element in the // output. - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -153,8 +193,9 @@ Status HloCostAnalysis::HandleTuple( tensorflow::gtl::ArraySlice operands) { // The tuple instruction only gathers pointers from inputs (it doesn't iterate // through them). The memory touched is then only the size of the output - // buffer. - current_bytes_accessed_ = shape_size_(tuple->shape()); + // index table of the tuple. 
+ + current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); return Status::OK(); } @@ -164,13 +205,11 @@ Status HloCostAnalysis::HandleConcatenate( return Status::OK(); } -Status HloCostAnalysis::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { +Status HloCostAnalysis::HandleConvert(HloInstruction* convert) { return HandleElementwiseOp(convert); } -Status HloCostAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status HloCostAnalysis::HandleCopy(HloInstruction* copy) { return Status::OK(); } @@ -194,7 +233,7 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot, } // We count an FMA operation as 2 floating point operations. - current_flop_count_ = kFmaFlops * fma_count; + current_properties_[kFlopsKey] = kFmaFlops * fma_count; return Status::OK(); } @@ -210,16 +249,17 @@ Status HloCostAnalysis::HandleMap( HloInstruction* map, tensorflow::gtl::ArraySlice operands, HloComputation* function, tensorflow::gtl::ArraySlice /*static_operands*/) { - // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + // Compute properties of the mapped function. + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Map operation. 
- int64 element_count = ShapeUtil::ElementsIn(map->shape()); - current_transcendental_count_ = - element_count * visitor.transcendental_count(); - current_flop_count_ = element_count * visitor.flop_count(); + const int64 element_count = ShapeUtil::ElementsIn(map->shape()); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * element_count; + } + } return Status::OK(); } @@ -227,16 +267,17 @@ Status HloCostAnalysis::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(reduce->shape()); - current_flop_count_ = reduction_count * visitor.flop_count(); - current_transcendental_count_ = - reduction_count * visitor.transcendental_count(); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * reduction_count; + } + } return Status::OK(); } @@ -244,55 +285,63 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, HloInstruction* operand, const Window& window, HloComputation* function) { - // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + // Compute the properties of the reduction function. 
+ TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each - // output element, (window_size - 1) number of user computations are applied. - auto output_size = ShapeUtil::ElementsIn(reduce_window->shape()); - int64 window_size = 1; + // output element there are window_size - 1 reductions to perform. + int64 window_element_count = 1; for (const auto& dimension : window.dimensions()) { - window_size *= dimension.size(); + window_element_count *= dimension.size(); + } + const int64 output_element_count = + ShapeUtil::ElementsIn(reduce_window->shape()); + const int64 reduction_count = + (window_element_count - 1) * output_element_count; + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * reduction_count; + } } - current_flop_count_ = output_size * (window_size - 1) * visitor.flop_count(); - current_transcendental_count_ = - output_size * (window_size - 1) * visitor.transcendental_count(); return Status::OK(); } Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { - // Compute the cost of the select and scatter function. - HloInstruction* select = instruction->select()->root_instruction(); - HloCostAnalysis select_visitor(shape_size_); - TF_RETURN_IF_ERROR(select->Accept(&select_visitor)); - HloInstruction* scatter = instruction->scatter()->root_instruction(); - HloCostAnalysis scatter_visitor(shape_size_); - TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor)); + // Compute the properties of the select and scatter function. + // Compute the properties of the reduction function. 
+ TF_ASSIGN_OR_RETURN(const Properties select_properties, + ProcessSubcomputation(instruction->select())); + TF_ASSIGN_OR_RETURN(const Properties scatter_properties, + ProcessSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter - // source element, (window_size - 1) number of select computations and 1 - // scatter computation are applied. + // source element there are window_size - 1 select computations to perform and + // 1 scatter computation to perform. const auto source = instruction->operand(1); const auto source_element_count = ShapeUtil::ElementsIn(source->shape()); - int64 window_size = 1; + int64 window_element_count = 1; for (const auto& dimension : instruction->window().dimensions()) { - window_size *= dimension.size(); + window_element_count *= dimension.size(); + } + const int64 select_count = source_element_count * (window_element_count - 1); + for (const auto& property : select_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] += property.second * select_count; + } + } + for (const auto& property : scatter_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] += + property.second * source_element_count; + } } - current_flop_count_ = - source_element_count * ((window_size - 1) * select_visitor.flop_count() + - scatter_visitor.flop_count()); - current_transcendental_count_ = - source_element_count * - ((window_size - 1) * select_visitor.transcendental_count() + - scatter_visitor.transcendental_count()); return Status::OK(); } Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) { // A bitcast does no computation and touches no memory. 
- current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -314,6 +363,17 @@ Status HloCostAnalysis::HandleReshape(HloInstruction* reshape) { return Status::OK(); } +Status HloCostAnalysis::HandleBatchNormTraining( + HloInstruction* batchNormTraining) { + // TODO(b/62294698): Implement cost analysis for batch-norm-training. + return Status::OK(); +} + +Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) { + // TODO(b/62294698): Implement cost analysis for batch-norm-grad. + return Status::OK(); +} + Status HloCostAnalysis::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } @@ -326,12 +386,13 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, const int64 output_features = convolution->shape().dimensions(dnums.feature_dimension()); - // For each output element, we do one fma per element in the - // kernel at some given output feature index. + // For each output element, we do one fma per element in the kernel at some + // given output feature index. const int64 fmas_per_output_element = ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - current_flop_count_ = output_elements * fmas_per_output_element * kFmaFlops; + current_properties_[kFlopsKey] = + output_elements * fmas_per_output_element * kFmaFlops; return Status::OK(); } @@ -341,7 +402,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. 
- current_flop_count_ = ShapeUtil::ElementsIn(crs->shape()); + current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); return Status::OK(); } @@ -350,31 +411,43 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random, // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. - current_transcendental_count_ = ShapeUtil::ElementsIn(random->shape()); + current_properties_[kTranscendentalsKey] = + ShapeUtil::ElementsIn(random->shape()); return Status::OK(); } Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { - // Compute the cost of the fused expression. - HloInstruction* fused_expression_root = fusion->fused_expression_root(); - // Don't compute sizes inside of fused ops. We don't use the size here and the - // operations inside might not have a layout. - HloCostAnalysis visitor([](const Shape&) { return 0; }); - TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor)); + // Compute the properties of the fused expression and attribute them to the + // fusion node. Use a dummy shape_size to avoid any errors from trying to + // calculate the size of a shape that does not have a layout, since nodes + // inside fusion nodes do not necessarily have a layout assigned. + ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; + TF_ASSIGN_OR_RETURN( + current_properties_, + ProcessSubcomputation(fusion->fused_instructions_computation(), + &shape_size)); + + // Fusion nodes that produce a tuple also produce the entries in the tuple. + // Ignore the memory accessed inside fused ops, since fusion is supposed to + // prevent intermediate data from touching slow memory. 
+ current_properties_[kBytesAccessedKey] = 0; + ShapeUtil::ForEachSubshape( + fusion->shape(), + [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { + current_properties_[kBytesAccessedKey] += shape_size_(subshape); + }); + + for (const HloInstruction* operand : fusion->operands()) { + current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); + } - // Attribute the cost of the fused expression to the fusion node. - current_transcendental_count_ = visitor.transcendental_count(); - current_flop_count_ = visitor.flop_count(); return Status::OK(); } Status HloCostAnalysis::HandleCall(HloInstruction* call) { - HloCostAnalysis computation_visitor(shape_size_); - TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor)); - - current_flop_count_ = computation_visitor.flop_count(); - current_transcendental_count_ = computation_visitor.transcendental_count(); - current_bytes_accessed_ = computation_visitor.bytes_accessed(); + TF_ASSIGN_OR_RETURN(current_properties_, + ProcessSubcomputation(call->to_apply())); + current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -382,34 +455,38 @@ Status HloCostAnalysis::HandleCustomCall( HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - return Unimplemented("custom-call"); + return Unimplemented("Custom-call is not implemented for HLO cost analysis."); } Status HloCostAnalysis::HandleSort(HloInstruction* sort, HloInstruction* operand_instruction) { - // The cost of sort is implementation dependent, so cannot determine at HLO - // level. Assume comparison based N*log(N) sorting. + // This assumes a comparison based N*log(N) algorithm. As for all ops, the + // actual properties of the op depend on the backend implementation. 
int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape()); - current_flop_count_ = elements * tensorflow::Log2Ceiling(elements); + current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements); return Status::OK(); } Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { - // Since the number of iterations of the while node is not statically - // determined, we cannot precisely compute the cost of a while node. For now - // compute the cost of a single iteration. - // TODO(b/26346211): Improve the cost analysis for while node. - HloCostAnalysis body_visitor(shape_size_); - TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor)); - HloCostAnalysis condition_visitor(shape_size_); - TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor)); + // Since the number of iterations of the while node will not always be + // something that we can statically analyze, we cannot precisely compute the + // cost of a while node. For now compute the cost of a single iteration. + // + // TODO(b/26346211): Improve the cost analysis for while nodes. 
+ TF_ASSIGN_OR_RETURN(const Properties body_properties, + ProcessSubcomputation(xla_while->while_body())); - current_flop_count_ = - body_visitor.flop_count() + condition_visitor.flop_count(); - current_transcendental_count_ = body_visitor.transcendental_count() + - condition_visitor.transcendental_count(); - current_bytes_accessed_ = - body_visitor.bytes_accessed() + condition_visitor.bytes_accessed(); + TF_ASSIGN_OR_RETURN(const Properties condition_properties, + ProcessSubcomputation(xla_while->while_condition())); + + current_properties_.clear(); + for (const auto& property : body_properties) { + current_properties_[property.first] += property.second; + } + for (const auto& property : condition_properties) { + current_properties_[property.first] += property.second; + } + current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -418,19 +495,42 @@ Status HloCostAnalysis::FinishVisit(HloInstruction* root) { return Status::OK(); } +float HloCostAnalysis::flop_count() const { + return GetProperty(kFlopsKey, properties_sum_); +} + +float HloCostAnalysis::transcendental_count() const { + return GetProperty(kTranscendentalsKey, properties_sum_); +} + +float HloCostAnalysis::bytes_accessed() const { + return GetProperty(kBytesAccessedKey, properties_sum_); +} + +float HloCostAnalysis::seconds() const { + return GetProperty(kSecondsKey, properties_sum_); +} + int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { - auto it = hlo_to_flop_count_.find(&hlo); - return it == hlo_to_flop_count_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_); } int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const { - auto it = hlo_to_transcendental_count_.find(&hlo); - return it == hlo_to_transcendental_count_.end() ? 
0 : it->second; + return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_); } int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { - auto it = hlo_to_bytes_accessed_.find(&hlo); - return it == hlo_to_bytes_accessed_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); +} + +StatusOr HloCostAnalysis::ProcessSubcomputation( + HloComputation* computation, const ShapeSizeFunction* shape_size) { + if (shape_size == nullptr) { + shape_size = &shape_size_; + } + HloCostAnalysis visitor(*shape_size, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.properties(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index b2c40f75ca4e833f1f5529977564b0e3a7ca25b1..3c2e9503aa626d9b9777d6650f219458a915f57d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -36,17 +36,22 @@ namespace xla { // operations separately from transcendental operations. class HloCostAnalysis : public DfsHloVisitor { public: + // Each HLO is associated to a vector of properties with the indices given + // below. Sub-classes can add further properties. + typedef std::map Properties; + static constexpr char kFlopsKey[] = "flops"; + static constexpr char kTranscendentalsKey[] = "transcendentals"; + static constexpr char kBytesAccessedKey[] = "bytes accessed"; + static constexpr char kSecondsKey[] = "seconds"; + // shape_size is a function which returns the size in bytes of the top-level // buffer of a shape. 
using ShapeSizeFunction = std::function; - explicit HloCostAnalysis(const ShapeSizeFunction& shape_size) - : shape_size_(shape_size) {} - - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override; - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override; + explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); + + Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override; + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, @@ -58,14 +63,14 @@ class HloCostAnalysis : public DfsHloVisitor { HloInstruction* lhs, HloInstruction* rhs) override; Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) override; + Status HandleReducePrecision(HloInstruction* hlo) override; Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) override; Status HandleSend(HloInstruction* send) override; Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, @@ -83,6 +88,8 @@ class HloCostAnalysis : public DfsHloVisitor { HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function_handle) override; + Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override; + Status HandleBatchNormGrad(HloInstruction* 
batchNormGrad) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, @@ -119,48 +126,88 @@ class HloCostAnalysis : public DfsHloVisitor { Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; - // Returns the amount of computations in the graph. - int64 flop_count() const { return flop_count_; } - int64 transcendental_count() const { return transcendental_count_; } + // Set the rates used to calculate the time taken by the computation. These + // need to be set before visiting starts. + void set_flops_per_second(float value) { + per_second_rates_[kFlopsKey] = value; + } + void set_transcendentals_per_second(float value) { + per_second_rates_[kTranscendentalsKey] = value; + } + void set_bytes_per_second(float value) { + per_second_rates_[kBytesAccessedKey] = value; + } + + // Returns properties for the computation. + float flop_count() const; + float transcendental_count() const; + float bytes_accessed() const; + float seconds() const; // Returns the respective cost computed for a particular HLO instruction, or 0 // if the HLO was not found to have a cost in the analysis. int64 flop_count(const HloInstruction& hlo) const; int64 transcendental_count(const HloInstruction& hlo) const; - - // Returns the number of bytes read/written. int64 bytes_accessed(const HloInstruction& hlo) const; - int64 bytes_accessed() const { return bytes_accessed_; } + float seconds(const HloInstruction& hlo) const; + + const Properties& properties() const { return properties_sum_; } + const float property(const string& key) const { + return GetProperty(key, properties()); + } + + protected: + typedef std::unordered_map HloToProperties; - private: // An FMA counts as two floating point operations in these analyzes. 
static constexpr int64 kFmaFlops = 2; + HloCostAnalysis(const ShapeSizeFunction& shape_size, + const Properties& per_second_rates); + + // Returns the properties computed from visiting the computation rooted at the + // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, + // otherwise uses shape_size_. + StatusOr ProcessSubcomputation( + HloComputation* computation, + const ShapeSizeFunction* shape_size = nullptr); + // Utility function to handle all element-wise operations. Status HandleElementwiseOp(HloInstruction* hlo_instruction); + // Returns 0.0f if the key is not present in the properties. Otherwise, + // returns the value that the key maps to from the properties parameter. + static float GetProperty(const string& key, const Properties& properties); + + // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key + // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that + // the key maps to in the properties of the given hlo. + static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, + const HloToProperties& hlo_to_properties); + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. const ShapeSizeFunction shape_size_; - // The total number of floating point operations, transcendental operations, - // and bytes accesses (read or written) in the computation. - int64 flop_count_ = 0; - int64 transcendental_count_ = 0; - int64 bytes_accessed_ = 0; - - // Cost counts of the current instruction. These should be set by each - // handlers if different from the default values computed in Preprocess. - int64 current_flop_count_; - int64 current_transcendental_count_; - int64 current_bytes_accessed_; - - // Mapping from HLO instructions to the cost we computed for them in the - // course of the graph analysis. 
- std::map hlo_to_flop_count_; - std::map hlo_to_transcendental_count_; - std::map hlo_to_bytes_accessed_; + HloToProperties hlo_properties_; + + // If true, the time taken will be computed from the rates for each property + // and the total time will be the maximum time, which is the time of the + // bottleneck. + bool current_should_compute_bottleneck_time_; + + // The properties of the currently visited instruction. A HandleFoo method can + // modify these to change the default values computed in Preprocess. + Properties current_properties_; + + // The sum of the properties of all HLOs in the computation. + Properties properties_sum_; + + // How much of each property can be processed per second. E.g. if the property + // is bytes accessed, this is the number of bytes that can be processed per + // second. Is empty if no rates have been set. + Properties per_second_rates_; TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); }; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index b74c7eb4e074bd8f340137066b6d9675bb32cee1..0a288a77ada840451915561b4b0865785b39ade7 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/compiler/xla/statusor.h" @@ -329,51 +330,67 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count()); } -using FusionCostAnalysis = ::testing::Test; +using FusionCostAnalysis = HloTestBase; TEST_F(FusionCostAnalysis, LoopFusion) { - Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); - - // Fuse all instructions in complicated expression: - // - // add = Add(C1, C2) - // clamp = Clamp(C2, add, add) - // exp = Exp(add) - // mul = Mul(exp, C3) - // sub = Sub(mul, clamp) - // tuple = Tuple({sub, sub, mul, C1}) - auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( - /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)); - auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( - /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)); - auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( - /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)); - - auto add = - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(), c2.get()); - auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2.get(), - add.get(), add.get()); - auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get()); - auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, - exp.get(), c3.get()); - auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, - mul.get(), clamp.get()); - auto tuple = - HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()}); - - auto fusion = HloInstruction::CreateFusion( - r2f32, HloInstruction::FusionKind::kLoop, tuple.get()); - 
fusion->FuseInstruction(sub.get()); - fusion->FuseInstruction(mul.get()); - fusion->FuseInstruction(exp.get()); - fusion->FuseInstruction(clamp.get()); - fusion->FuseInstruction(add.get()); - - HloCostAnalysis fusion_analysis(ShapeSize); - ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); - - EXPECT_EQ(fusion_analysis.flop_count(), 16); - EXPECT_EQ(fusion_analysis.transcendental_count(), 4); + // Do this 4 times with different per-second rates to test the computation of + // bottleneck time on fusion nodes. + for (int i = 0; i < 4; ++i) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + + // Fuse all instructions in complicated expression: + // + // add = Add(C1, C2) + // clamp = Clamp(C2, add, add) + // exp = Exp(add) + // mul = Mul(exp, C3) + // sub = Sub(mul, clamp) + // tuple = Tuple({sub, sub, mul, C1}) + HloComputation::Builder builder(TestName()); + auto c1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2))); + auto c2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2))); + auto c3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2)); + auto clamp = builder.AddInstruction( + HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); + auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); + + HloModule module(TestName()); + auto* computation = 
module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); + + // The time given these rates at i == 0 is exactly even among the properties + // at 1.0 seconds. For other values, one of the rates is slower so that it + // becomes the bottleneck. + HloCostAnalysis fusion_analysis(ShapeSize); + fusion_analysis.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0)); + fusion_analysis.set_transcendentals_per_second(4 * + (i == 2 ? 1 / 4.0 : 1.0)); + fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0)); + ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); + + EXPECT_EQ(fusion_analysis.flop_count(), 16); + EXPECT_EQ(fusion_analysis.transcendental_count(), 4); + constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2; + static_assert(bytes_accessed == 64, ""); + EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed); + + EXPECT_EQ(fusion_analysis.seconds(), 1 << i); + } } TEST_F(FusionCostAnalysis, NoLayout) { @@ -382,19 +399,21 @@ TEST_F(FusionCostAnalysis, NoLayout) { Shape shape_without_layout = shape_with_layout; shape_without_layout.clear_layout(); - auto c1 = HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); - auto c2 = - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); - - auto broadcast = - HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1}); - auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd, - c1.get(), broadcast.get()); - - auto fusion = HloInstruction::CreateFusion( - shape_with_layout, HloInstruction::FusionKind::kLoop, add.get()); - fusion->FuseInstruction(broadcast.get()); + HloComputation::Builder builder(TestName()); + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); + auto c2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 
3}))); + + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(shape_without_layout, c2, {1})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + shape_with_layout, HloOpcode::kAdd, c1, broadcast)); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); HloCostAnalysis fusion_analysis(ShapeSize); ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 4c6af5c40fa563d1c656eb152819e454aae5fb69..690c084efb131e9b075ced17bfcd0b23a23218f1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -68,7 +68,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (LiteralUtil::Equal(instruction->literal(), it->second->literal())) { + if (instruction->literal().Equal(it->second->literal())) { match = it->second; break; } @@ -92,6 +92,9 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { StatusOr HloCSE::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } changed |= CombineConstants(computation.get(), is_layout_sensitive_); std::list post_order = diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index cc39c3ac20396f9648b5d325933aad819275b2a6..8b0b9c8bbd0cf442149b32a4539277b2daeed90e 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -51,9 +51,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two 
identical constants are commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -67,10 +67,10 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = computation->instructions().begin()->get(); - EXPECT_EQ(42.0f, LiteralUtil::Get(constant->literal(), {})); + EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR0(84.0); + auto expected = Literal::CreateR0(84.0); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -102,7 +102,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -132,7 +132,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -141,20 +141,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // commoned. 
auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); // Duplicate the float constant to verify something happens. builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. 
auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -206,7 +206,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -236,7 +236,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -267,7 +267,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -311,7 +311,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -349,9 +349,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -392,9 +392,9 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto rng = builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -409,7 +409,7 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); + 
HloInstruction::CreateConstant(Literal::CreateR1({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d1b87256445e4fd51134a66666e5736baf272c71..92548dfaf0bf12755053bfe26d4cb2ae0459dd37 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -16,14 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include -#include #include -#include #include #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -35,7 +32,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -43,209 +39,6 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -string HloLocation::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; - return StrCat(instruction->FullyQualifiedName(), index_str); -} - -std::ostream& operator<<(std::ostream& out, const HloLocation& location) { - out << location.ToString(); - return out; -} - -string HloUse::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) - ? 
(" " + operand_index.ToString()) - : ""; - return StrCat(instruction->FullyQualifiedName(), ", operand ", operand_number, - index_str); -} - -std::ostream& operator<<(std::ostream& out, const HloUse& use) { - out << use.ToString(); - return out; -} - -HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, - const ShapeIndex& index, bool is_phi) - : id_(id), is_phi_(is_phi) { - // The defining location is always the first element in the locations_ vector. - AddLocation(instruction, index); -} - -bool HloValue::operator==(const HloValue& other) const { - bool equal = instruction() == other.instruction() && index() == other.index(); - // If the values are equal they most both be phi (or non phi). - CHECK(!(equal && is_phi() != other.is_phi())); - return equal; -} - -bool HloValue::operator!=(const HloValue& other) const { - return !(*this == other); -} - -string HloValue::ToShortString() const { - string index_str = - ShapeUtil::IsTuple(instruction()->shape()) ? index().ToString() : ""; - return StrCat(is_phi_ ? "PHI " : "", instruction()->FullyQualifiedName(), - index_str); -} - -string HloValue::ToString(int indent) const { - string indentation(indent, ' '); - string out = StrCat(indentation, ToShortString(), ", locations:\n"); - for (const HloLocation& location : locations()) { - StrAppend(&out, indentation, " ", location.ToString(), "\n"); - } - StrAppend(&out, indentation, " uses:\n"); - for (const HloUse& use : uses()) { - StrAppend(&out, indentation, " ", use.ToString(), "\n"); - } - return out; -} - -void HloValue::AddLocation(HloInstruction* instruction, - const ShapeIndex& index) { - // The given location should not already exist in locations_. - for (const HloLocation& location : locations_) { - DCHECK(!(location.instruction == instruction && location.index == index)); - } - - locations_.push_back(HloLocation{instruction, index}); - - // Update uses. 
- for (HloInstruction* user : instruction->users()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - if (!DoesNotUseOperandBuffer(instruction, index, user)) { - for (const HloUse& use : uses_) { - // Verify that this use does not already exist. - DCHECK(!(use.instruction == user && - use.operand_number == operand_number && - use.operand_index == index)); - } - - uses_.push_back(HloUse{user, operand_number, index}); - } - } - } - - // Update liveout status of this HloValue. - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - } -} - -void HloValue::RemoveLocation(HloInstruction* instruction, - const ShapeIndex& index) { - // The defining location cannot be removed. - CHECK(!(instruction == this->instruction() && index == this->index())); - - int64 size_before = locations_.size(); - locations_.erase( - std::remove_if(locations_.begin(), locations_.end(), - [instruction, &index](const HloLocation& location) { - return location.instruction == instruction && - location.index == index; - }), - locations_.end()); - // Only a single location should have been removed. - CHECK_EQ(locations_.size(), size_before - 1); - - // Update uses which referred to this location. - uses_.erase(std::remove_if(uses_.begin(), uses_.end(), - [instruction, &index](const HloUse& use) { - return use.instruction->operand( - use.operand_number) == instruction && - use.operand_index == index; - }), - uses_.end()); - - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - // Value has been removed from a location in the entry root instruction. - // Check if the value is still live out of the module by walking all - // remaining locations. 
- live_out_of_module_ = false; - for (const HloLocation& location : locations()) { - if (location.instruction == - module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - break; - } - } - } -} - -std::ostream& operator<<(std::ostream& out, const HloValue& value) { - out << value.ToShortString(); - return out; -} - -void HloValueSet::SortAndUniquifyValues() { - std::sort(value_ids_.begin(), value_ids_.end()); - value_ids_.erase(std::unique(value_ids_.begin(), value_ids_.end()), - value_ids_.end()); -} - -string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", tensorflow::str_util::Join(value_ids_, ", ")); -} - -/*static */ -HloValueSet HloValueSet::Union( - tensorflow::gtl::ArraySlice inputs) { - HloValueSet union_set; - for (const HloValueSet* input : inputs) { - for (HloValue::Id value_id : input->value_ids()) { - union_set.value_ids_.push_back(value_id); - } - } - union_set.SortAndUniquifyValues(); - return union_set; -} - -std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { - out << value_set.ToString(); - return out; -} - -InstructionValueSet InstructionValueSet::Union( - tensorflow::gtl::ArraySlice inputs) { - CHECK_GT(inputs.size(), 0); - for (int i = 1; i < inputs.size(); ++i) { - CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); - } - InstructionValueSet union_set(inputs[0]->shape()); - union_set.ForEachMutableElement( - [&inputs](const ShapeIndex& index, HloValueSet* value_set) { - std::vector input_sets; - for (const InstructionValueSet* input : inputs) { - input_sets.push_back(&input->element(index)); - } - *value_set = HloValueSet::Union(input_sets); - }); - return union_set; -} - -std::ostream& operator<<(std::ostream& out, - const InstructionValueSet& instruction_value_set) { - out << instruction_value_set.ToString(); - return out; -} - -string InstructionValueSet::ToString() const { - string out = - StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), 
")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloValueSet& value_set) { - StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); - }); - return out; -} - HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, bool bitcast_defines_value) : module_(module), @@ -256,10 +49,10 @@ HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { const HloValueSet& value_set = GetValueSet(instruction, index); - if (value_set.value_ids().size() != 1) { + if (value_set.values().size() != 1) { return false; } - return GetValue(value_set.GetUniqueValueId()).instruction() == instruction; + return value_set.GetUniqueValue().defining_instruction() == instruction; } const HloValue& HloDataflowAnalysis::GetValueDefinedAt( @@ -274,20 +67,20 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt( return GetUniqueValueAt(instruction, index); } -HloValue::Id HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, - const ShapeIndex& index, - bool is_phi) { - int64 value_id = next_value_id_++; - auto it_added = values_.emplace( +HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, + const ShapeIndex& index, + bool is_phi) { + const int64 value_id = next_value_id_++; + auto emplaced = values_.emplace( std::piecewise_construct, std::forward_as_tuple(value_id), std::forward_as_tuple(value_id, instruction, index, is_phi)); - CHECK(it_added.second); + CHECK(emplaced.second); // Clear the vector of values as it is now stale. It will be lazily // reconstructed if needed when HloDataflowAnalysis::values() is called. 
values_vector_.clear(); - return value_id; + return &emplaced.first->second; } void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { @@ -305,16 +98,16 @@ string HloDataflowAnalysis::ToString() const { module_->computations()) { for (const std::unique_ptr& instruction : computation->instructions()) { - StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n"); + StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { GetInstructionValueSet(instruction.get()) .ForEachElement([this, &instruction, &out]( const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " tuple index ", index.ToString(), ":\n"); - for (HloValue::Id value_id : value_set.value_ids()) { + for (const HloValue* value : value_set.values()) { StrAppend( - &out, " ", GetValue(value_id).ToShortString(), + &out, " ", value->ToShortString(), ValueIsDefinedAt(instruction.get(), index) ? " (def)" : "", "\n"); } @@ -322,8 +115,8 @@ string HloDataflowAnalysis::ToString() const { } else { const HloValueSet& top_level_value_set = GetValueSet(instruction.get(), /*index=*/{}); - for (HloValue::Id value_id : top_level_value_set.value_ids()) { - StrAppend(&out, " ", GetValue(value_id).ToShortString(), + for (const HloValue* value : top_level_value_set.values()) { + StrAppend(&out, " ", value->ToShortString(), ValueIsDefinedAt(instruction.get()) ? 
" (def)" : "", "\n"); } } @@ -361,9 +154,8 @@ const std::vector& HloDataflowAnalysis::values() const { for (auto& pair : values_) { values_vector_.push_back(&pair.second); } - std::sort( - values_vector_.begin(), values_vector_.end(), - [](const HloValue* a, const HloValue* b) { return a->id() < b->id(); }); + std::sort(values_vector_.begin(), values_vector_.end(), + HloValue::IdLessThan); } else { CHECK_EQ(values_vector_.size(), values_.size()); for (const HloValue* value : values_vector_) { @@ -405,8 +197,8 @@ InstructionValueSet HloDataflowAnalysis::Phi( // Construct a vector of unique value IDs of the inputs. std::vector input_value_ids; for (const InstructionValueSet* input : inputs) { - for (HloValue::Id value_id : input->element(index).value_ids()) { - input_value_ids.push_back(value_id); + for (const HloValue* value : input->element(index).values()) { + input_value_ids.push_back(value->id()); } } std::sort(input_value_ids.begin(), input_value_ids.end()); @@ -427,7 +219,7 @@ InstructionValueSet HloDataflowAnalysis::Phi( if (input_value_ids.size() <= 1) { if (input_value_ids.size() == 1) { - *value_set = HloValueSet({input_value_ids[0]}); + *value_set = HloValueSet({&GetValue(input_value_ids[0])}); } if (existing_phi_value) { // The merge point does not have multiple distinct inputs (which are @@ -442,7 +234,7 @@ InstructionValueSet HloDataflowAnalysis::Phi( if (existing_phi_value) { // A phi value already exists so reuse it in the new // InstructionValueSet. - *value_set = HloValueSet({existing_phi_value->id()}); + *value_set = HloValueSet({existing_phi_value}); } else { // Create a new phi value. 
*value_set = @@ -453,39 +245,37 @@ InstructionValueSet HloDataflowAnalysis::Phi( return new_value_set; } -void HloDataflowAnalysis::UpdateLocationsOfValuesAt( +void HloDataflowAnalysis::UpdatePositionsOfValuesAt( HloInstruction* instruction, const InstructionValueSet& new_value_set, const InstructionValueSet* prev_value_set) { if (prev_value_set != nullptr) { - // Remove locations from the old value set. + // Remove positions from the old value set. prev_value_set->ForEachElement( [this, instruction](const ShapeIndex& index, const HloValueSet& value_set) { - for (HloValue::Id value_id : value_set.value_ids()) { + for (const HloValue* value : value_set.values()) { // HloValues in the previous value set may have been deleted. - if (!ContainsKey(values_, value_id)) { + if (!ContainsKey(values_, value->id())) { continue; } - // Don't remove the defining location of the value. - HloValue& value = GetValue(value_id); - if (instruction == value.instruction()) { - CHECK_EQ(index, value.index()); + // Don't remove the defining position of the value. + if (instruction == value->defining_instruction()) { + CHECK_EQ(index, value->defining_index()); } else { - value.RemoveLocation(instruction, index); + GetValue(value->id()).RemovePosition(instruction, index); } } }); } - // Add locations in the new value set. + // Add positions in the new value set. 
new_value_set.ForEachElement( [this, instruction](const ShapeIndex& index, const HloValueSet& value_set) { - for (HloValue::Id value_id : value_set.value_ids()) { - HloValue& value = GetValue(value_id); - if (instruction == value.instruction()) { - CHECK_EQ(index, value.index()); + for (const HloValue* value : value_set.values()) { + if (instruction == value->defining_instruction()) { + CHECK_EQ(index, value->defining_index()); } else { - value.AddLocation(instruction, index); + GetValue(value->id()).AddPosition(instruction, index); } } }); @@ -672,7 +462,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( // Update uses. First clear all of the old uses at the particular // operands. Then add the new uses. There may be overlap between the old // uses and new uses. - UpdateLocationsOfValuesAt(instruction, GetInstructionValueSet(instruction), + UpdatePositionsOfValuesAt(instruction, GetInstructionValueSet(instruction), &old_value); } } @@ -694,15 +484,24 @@ InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet( std::vector inputs; bool called_from_while = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { - inputs.push_back(&GetInstructionValueSet( - callsite.instruction()->operand(parameter->parameter_number()))); - if (callsite.instruction()->opcode() == HloOpcode::kWhile) { - // In a while instruction, the backedge is also a dataflow input to the - // parameter instruction. This code covers the case where the parameter is - // in the while body or the parameter is in the while condition. + if (callsite.instruction()->opcode() == HloOpcode::kCall) { + // The operand values of a call instruction are forwarded to the + // respective parameter instruction of the subcomputation. 
+ inputs.push_back(&GetInstructionValueSet( + callsite.instruction()->operand(parameter->parameter_number()))); + } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + // In a while instruction, the while operand (ie, the init value) and the + // backedge are dataflow inputs to the parameter instruction. This is the + // case for parameters of both the body and condition computations. + CHECK_EQ(parameter->parameter_number(), 0); + inputs.push_back( + &GetInstructionValueSet(callsite.instruction()->operand(0))); inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); called_from_while = true; + } else { + LOG(FATAL) << "CallContext::kSequential computations should only be " + "called from call or while instructions"; } } @@ -797,13 +596,156 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); break; } - UpdateLocationsOfValuesAt(instruction.get(), + UpdatePositionsOfValuesAt(instruction.get(), GetInstructionValueSet(instruction.get())); } } return Status::OK(); } +bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const { + // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' + // is live into the module. + if (b.defining_instruction()->parent() == module_->entry_computation() && + b.defining_instruction()->opcode() == HloOpcode::kParameter) { + return false; + } + + // Phi values require special handling. Because XLA does not have a phi + // instruction, the defining instructions of phi values are + // placeholders: either the subcomputation parameter (body or condition) or + // the while instruction. However, the program point where these values are + // logically defined does not necessarily coincide exactly with program point + // of these place-holder instructions. 
So we explicitly define the following + // order for phi values: + // + // body/condition parameter phi: + // Defined before all values defined in its computation excepting other + // phis. + // + // while phi: + // defined after all values defined in the condition or body. + // + auto is_body_or_condition_phi = [](const HloValue& v) { + return v.is_phi() && + v.defining_instruction()->opcode() == HloOpcode::kParameter; + }; + if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + return true; + } + if (is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(a.defining_instruction(), + b.defining_instruction()->parent())) { + return false; + } + + // If 'b' is a while phi and 'a' is in the body or condition, then 'a' + // executes before 'b'. + if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), b.defining_instruction()->while_body()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->while_condition()))) { + return true; + } + + return ordering.ExecutesBefore(a.defining_instruction(), + b.defining_instruction()); +} + +bool HloDataflowAnalysis::UseIsBeforeValueDefinition( + const HloUse& use, const HloValue& value, + const HloOrdering& ordering) const { + if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) { + return true; + } + + // If the use is at the instruction where the value is defined, then the use + // is before the def if the instruction allows buffer sharing (in place + // computation). 
+ if (use.instruction == value.defining_instruction() && + CanShareOperandBufferWithUser( + use.instruction->mutable_operand(use.operand_number), + use.operand_index, value.defining_instruction(), + value.defining_index())) { + return true; + } + + // The use at a while is an input to a phi, and logically occurs before values + // are defined in the body or condition computations. + if (use.instruction->opcode() == HloOpcode::kWhile) { + const HloInstruction* xla_while = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_condition())) { + return true; + } + } + + // Similarly if the value is defined at a while, it logically occurs after any + // uses in the body or condition computations. + if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { + CHECK(ssa_form_); + const HloInstruction* xla_while = value.defining_instruction(); + if (call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_condition())) { + return true; + } + } + return false; +} + +bool HloDataflowAnalysis::LiveRangeStrictlyBefore( + const HloValue& a, const HloValue& b, const HloOrdering& ordering) const { + VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() + << ", b = " << b.ToShortString() << ")"; + if (!IsDefinedBefore(a, b, ordering)) { + VLOG(4) << "a not defined before b"; + return false; + } + + // Live-out values from the module can never have ranges strictly before any + // other value. + if (a.live_out_of_module()) { + VLOG(4) << "a is live out of module"; + return false; + } + + // Live-out values of computations can never have ranges strictly before any + // other value in the computation (including values nested in + // subcomputations). 
+ if (a.live_out_of_computation() && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + VLOG(4) << "a is live out of computation containing b"; + return false; + } + + // All uses of 'a' must be before 'b' is defined. + for (const HloUse& use : a.uses()) { + if (!UseIsBeforeValueDefinition(use, b, ordering)) { + VLOG(4) << "use of a (" << use << ") not before b is defined"; + return false; + } + } + + return true; +} + +bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const { + // Buffers without disjoint liveness may interfere. + return !LiveRangeStrictlyBefore(a, b, ordering) && + !LiveRangeStrictlyBefore(b, a, ordering); +} + /* static */ StatusOr> HloDataflowAnalysis::Run( HloModule* module, bool ssa_form, bool bitcast_defines_value) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 2f9b0a64be5a00f490e5fc678ac5589e374f80d7..4eb4f0bb16768bee9eaae8d19f578dad242dbb2e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Analysis for determining the possible set of values for all locations +// Analysis for determining the possible set of values for all positions // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped // tracking values across computation boundaries. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ -#include +#include #include #include #include @@ -28,222 +28,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" namespace xla { -// Abstraction which identifies a specific point in the XLA graph. An -// HloLocation specifies a ShapeIndex within the output of a specific -// instruction. -struct HloLocation { - HloInstruction* instruction; - ShapeIndex index; - - string ToString() const; - - bool operator==(const HloLocation& other) const { - return instruction == other.instruction && index == other.index; - } - bool operator!=(const HloLocation& other) const { return !(*this == other); } -}; - -std::ostream& operator<<(std::ostream& out, const HloLocation& location); - -// Defines a single use of an HLO value. -struct HloUse { - // Instruction at which the value is used. - HloInstruction* instruction; - - // The operand number in which the value is appears. - int64 operand_number; - - // The shape index within the operand in which the value appears. 
- ShapeIndex operand_index; - - string ToString() const; - - bool operator==(const HloUse& other) const { - return instruction == other.instruction && - operand_number == other.operand_number && - operand_index == other.operand_index; - } - - bool operator!=(const HloUse& other) const { return !(*this == other); } -}; - -std::ostream& operator<<(std::ostream& out, const HloUse& use); - -// Class describing a value used by the dataflow analysis. XLA arrays are -// trivially a single HloValue. Tuples are made up of more than one HloValue: an -// HloValue for the pointer vector, and an HloValue for each child element. -// -// Every HloValue is defined by a particular instruction and most instructions -// define only a single HloValue. Instructions which define a single HloValue -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single HloValue -// which is a vector of pointers to the values containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple values only the top-level HloValue (the vector of pointers) is -// defined by the Tuple instruction. The values containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one HloValue. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. These tuple-shaped instructions do -// not assemble a tuple from existing HloValues like the Tuple instruction does, -// but rather define all the HloValues in the tuple. -class HloValue { - public: - using Id = int64; - - // Construct an HloValue defined by 'instruction' at shape index 'index'. If - // is_phi is true, then this value is a phi value, for example, at the - // parameter of a while body computation. 
Phi values are only used in the SSA - // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). - HloValue(HloValue::Id id, HloInstruction* instruction, - const ShapeIndex& index, bool is_phi = false); - - // Return a unique identifier for this HloValue. This value is used for stable - // sorting and iteration - Id id() const { return id_; } - - // Returns whether this value is a phi value. - bool is_phi() const { return is_phi_; } - - // Return the location where this value is defined. - const HloLocation& DefinitionLocation() const { return locations_[0]; } - - // Return the instruction which defines this HloValue. - HloInstruction* instruction() const { - return DefinitionLocation().instruction; - } - - // Return the shape index at which this HloValue is defined in the output of - // instruction(). - const ShapeIndex& index() const { return DefinitionLocation().index; } - - // Add or remove a location at which the HloValue appears. The definition - // location can not be removed. The uses of the HloValue are updated. - void AddLocation(HloInstruction* instruction, const ShapeIndex& index); - void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index); - - // Return all locations of the HloValue in the module. - const std::vector& locations() const { return locations_; } - - // Return all uses of the HloValue. - const std::vector& uses() const { return uses_; } - - // Set/get whether this HloValue is live out of the module. - bool live_out_of_module() const { return live_out_of_module_; } - - bool operator==(const HloValue& other) const; - bool operator!=(const HloValue& other) const; - - // Return a single-line string representation of the value. - string ToShortString() const; - - string ToString(int indent = 0) const; - - private: - // Unique identifier for this HloValue. Used for stable sorting and iteration. - const Id id_; - - // Whether this instruction is a phi value. - const bool is_phi_; - - // The set of locations of this HloValue. 
The first element is always the - // location of the definition. - std::vector locations_; - - // The set of uses of this HloValue. - std::vector uses_; - - // Whether this value is live out of the HLO module. - bool live_out_of_module_ = false; -}; - -std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); - -// A class representing the possible set of HloValues at a particular point -// (shape index in the output of an instruction) in the XLA graph. This set -// contains the set of reaching HloValue definitions. For a simple array-shaped -// instruction like Add, the HloValueSet of the top-level of the instruction's -// output trivially contains only the HloValue defined by the instruction. For -// instructions which have non-trivial dataflow such as Tuple or Select, the -// HloValueSets of the instruction's output contains one or more HloValues -// defined by the instruction's operands or defined further up in the XLA graph. -class HloValueSet { - public: - HloValueSet() = default; - - explicit HloValueSet(tensorflow::gtl::ArraySlice value_ids) - : value_ids_(value_ids.begin(), value_ids.end()) { - SortAndUniquifyValues(); - } - - // Return the union of the given HloValueSets. - static HloValueSet Union( - tensorflow::gtl::ArraySlice inputs); - - // Return the vector of the IDs of all HloValues in the set. Values in the - // vector are unique and sorted. - const std::vector& value_ids() const { return value_ids_; } - - // Return the unique HLO value in the set. CHECKs if the set does not contain - // exactly one value. - HloValue::Id GetUniqueValueId() const { - CHECK_EQ(value_ids().size(), 1); - return value_ids()[0]; - } - - bool operator==(const HloValueSet& other) const { - return value_ids() == other.value_ids(); - } - bool operator!=(const HloValueSet& other) const { return !(*this == other); } - - string ToString() const; - - private: - // Sorts value_ and removes duplicates. This should be called after adding any - // elements to values_. 
- void SortAndUniquifyValues(); - - // HloValues sorted by HloValue::Id. - std::vector value_ids_; -}; - -std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); - -// A class collecting the HloValues which might be contained in the output of -// an HLO instruction. For array-shaped instructions, an InstructionValueSet -// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets -// hold multiple HloValueSets. -class InstructionValueSet : public ShapeTree { - public: - InstructionValueSet(const Shape& shape) : ShapeTree(shape) {} - - // Return the union of the given InstructionValueSets. - static InstructionValueSet Union( - tensorflow::gtl::ArraySlice inputs); - - string ToString() const; -}; - -std::ostream& operator<<(std::ostream& out, - const InstructionValueSet& instruction_value_set); - // Analysis which identifies all HLO values and their uses in an HLO module. class HloDataflowAnalysis { public: @@ -298,17 +94,28 @@ class HloDataflowAnalysis { // shape index. CHECKs if the value set does not contain a exactly one value. const HloValue& GetUniqueValueAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const { - return GetValue(GetValueSet(instruction, index).GetUniqueValueId()); + return GetValueSet(instruction, index).GetUniqueValue(); } HloValue& GetUniqueValueAt(const HloInstruction* instruction, const ShapeIndex& index = {}) { - return GetValue(GetValueSet(instruction, index).GetUniqueValueId()); + return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); } // Return the HloValue with the given Id. const HloValue& GetValue(HloValue::Id value_id) const; HloValue& GetValue(HloValue::Id value_id); + // Returns whether the given values interfere assuming the given HLO + // ordering. Two values interfere if they may both be simultaneously live. + bool MayInterfere(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Overload which takes HloValue:Ids. 
+ bool MayInterfere(HloValue::Id a, HloValue::Id b, + const HloOrdering& ordering) const { + return MayInterfere(GetValue(a), GetValue(b), ordering); + } + // Return the total number of HloValues. int64 value_count() const { return values_.size(); } @@ -323,10 +130,9 @@ class HloDataflowAnalysis { HloDataflowAnalysis(HloModule* module, bool ssa_form, bool bitcast_defines_value = false); - // Creates a new HloValue defined at the given instruction and shape index and - // return its ID. - HloValue::Id NewHloValue(HloInstruction* instruction, const ShapeIndex& index, - bool is_phi = false); + // Returns a new HloValue defined at the given instruction and shape index. + HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, + bool is_phi = false); // Delete the HloValue with the given ID. void DeleteHloValue(HloValue::Id value_id); @@ -363,24 +169,40 @@ class HloDataflowAnalysis { tensorflow::gtl::ArraySlice inputs, bool skip_top_level = false); - // Updates the locations of the HloValues in the output of the given + // Updates the positions of the HloValues in the output of the given // instruction. This should be called after the instruction value set of // 'instruction' has been changed. 'prev_value_set' must point to the previous // state of the value set prior to the change. 'prev_value_set' may be null if - // this is the first time locations are being computed. The previous state is - // necessary to efficiently remove locations which have been eliminated due to + // this is the first time positions are being computed. The previous state is + // necessary to efficiently remove positions which have been eliminated due to // changes in the instructions' InstructionValueSet. 
- void UpdateLocationsOfValuesAt( + void UpdatePositionsOfValuesAt( HloInstruction* instruction, const InstructionValueSet& new_value_set, const InstructionValueSet* prev_value_set = nullptr); + // Returns true if the live range of the given value 'a' is strictly before + // the live range of value 'b' using the given HLO ordering. + bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Returns whether the value 'a' is defined before the value 'b' under the + // given ordering. + bool IsDefinedBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Returns whether the given use is before the given value definition. + bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value, + const HloOrdering& ordering) const; + HloModule* const module_; const bool ssa_form_; const bool bitcast_defines_value_; std::unique_ptr call_graph_; - // The map of all HloValues in the module. + // The map of all HloValues in the module. We pass around pointers to the + // mapped HloValues, so the underlying container must keep them valid despite + // mutations touching other map entries. std::unordered_map values_; // A map from instruction to InstructionValueSet. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 21344af5f224843a857984162a36b8a09915e607..2b685e355f0ce4d856639c29c3b1b254b068ef7b 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -39,46 +39,58 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(TestName()) {} + HloDataflowAnalysisTest() : module_(CreateNewModule()) {} // Run dataflow analysis on the member module. 
For convenience returns a // reference to the generated analysis stored in analysis_. const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false) { analysis_ = - HloDataflowAnalysis::Run(&module_, ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } - // Return a vector of the HloValues at the given program location. + // Return a vector of the HloValues at the given program position. std::vector HloValuesAt(const HloInstruction* instruction, const ShapeIndex& index = {}) { CHECK(analysis_ != nullptr); std::vector values; - for (HloValue::Id value_id : - analysis_->GetValueSet(instruction, index).value_ids()) { - values.push_back(analysis_->GetValue(value_id)); + for (const HloValue* value : + analysis_->GetValueSet(instruction, index).values()) { + values.push_back(*value); } return values; } - HloModule module_; + // Returns true if the top-level values for instructions 'a' and 'b' may + // interfere. Precondition: 'a' and 'b' define array-shaped values. + bool InstructionsMayInterfere(const HloOrdering& ordering, + const HloInstruction* a, + const HloInstruction* b) { + EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); + EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a), + analysis_->GetValueDefinedAt(b), ordering); + } + + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); + const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { // Test the dataflow for a simple binary operation (Add). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -89,14 +101,14 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) { EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); - // Verify the locations of the values. These locations are all trivial because + // Verify the positions of the values. These positions are all trivial because // there are no instructions which forward values. - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).locations(), - UnorderedElementsAre(HloLocation{constant1, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant2).locations(), - UnorderedElementsAre(HloLocation{constant2, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(add).locations(), - UnorderedElementsAre(HloLocation{add, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).positions(), + UnorderedElementsAre(HloPosition{constant1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).positions(), + UnorderedElementsAre(HloPosition{constant2, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(add).positions(), + UnorderedElementsAre(HloPosition{add, {}})); // Verify the uses of the values. 
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), @@ -126,7 +138,7 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -143,42 +155,36 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1)); EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); - // Verify the locations of the values. + // Verify the positions of the values. EXPECT_THAT( - analysis.GetValueDefinedAt(param0).locations(), - UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, - HloLocation{gte0, {}})); + analysis.GetValueDefinedAt(param0).positions(), + UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, + HloPosition{gte0, {}})); EXPECT_THAT( - analysis.GetValueDefinedAt(param1).locations(), - UnorderedElementsAre(HloLocation{param1, {}}, HloLocation{tuple, {1}}, - HloLocation{gte1, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(tuple).locations(), - UnorderedElementsAre(HloLocation{tuple, {}})); + analysis.GetValueDefinedAt(param1).positions(), + UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}}, + HloPosition{gte1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(), + UnorderedElementsAre(HloPosition{tuple, {}})); // Verify uses. Of interest is that a GetTupleElement instruction is only a // use of the top-level value in the tuple operand. 
EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(), - UnorderedElementsAre(HloUse{tuple, 0, {}}, HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(), - UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedTuple) { - // Verify the dataflow through a nested tuple of the following form for two - // constants %constant1 and %constant2: - // - // %nested_tuple = {{%constant1, %constant2}, - // {%constant1, %constant2}, - // %constant1} - // + // Verify the dataflow through a nested tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto nested_tuple = builder.AddInstruction( @@ -187,33 +193,30 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1)); auto gte_out = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); EXPECT_EQ(analysis.values().size(), 4); - // Verify locations and uses. + // Verify positions and uses. 
EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).locations(), + analysis.GetValueDefinedAt(constant1).positions(), UnorderedElementsAre( - HloLocation{constant1, {}}, HloLocation{tuple, {0}}, - HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}}, - HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}}, - HloLocation{gte_out, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre( - HloUse{tuple, 0, {}}, HloUse{nested_tuple, 0, {0}}, - HloUse{nested_tuple, 1, {0}}, HloUse{nested_tuple, 2, {}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{nested_tuple, 0, {1}}, - HloUse{nested_tuple, 1, {1}})); + HloPosition{constant1, {}}, HloPosition{tuple, {0}}, + HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}}, + HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}}, + HloPosition{gte_out, {}})); + // Constant values should have no uses though one is live out. The positions + // where they appear as operands are on instructions which do not use the + // values (eg, Tuple). + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + + // The top-level tuple values are used in GTE instructions. 
EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), - UnorderedElementsAre(HloUse{nested_tuple, 0, {}}, - HloUse{nested_tuple, 1, {}}, - HloUse{gte_out, 0, {}})); + UnorderedElementsAre(HloUse{gte_out, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(), UnorderedElementsAre(HloUse{gte_tuple, 0, {}})); @@ -236,16 +239,16 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -268,11 +271,12 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + 
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -285,20 +289,20 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kSubtract, call1, call2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -316,17 +320,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call1, 0, {}}, - HloUse{call2, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call1, 1, {}}, - HloUse{call2, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); // The Add from the 
subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -339,18 +344,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -392,7 +397,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1)); HloComputation* inner_computation = - 
module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); auto outer_builder = HloComputation::Builder("OuterComputation"); auto outer_param0 = outer_builder.AddInstruction( @@ -400,19 +405,19 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. - auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( + outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); - auto call = builder.AddInstruction(HloInstruction::CreateCall( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -423,14 +428,10 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. 
- EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, - HloUse{add, 1, {}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, - HloUse{add, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -465,33 +466,37 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - auto body_tuple = body_builder.AddInstruction( + body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); // Condition computation trivially returns a constant "false". 
auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + EXPECT_TRUE( + analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); + EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); + if (ssa_form) { // Element 0 of the tuple passed through the body so no phi value is // defined. 
@@ -507,15 +512,17 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{tuple, 0, {}}, - HloUse{xla_while, 0, {0}}, - HloUse{body_tuple, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); + + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values // in non-ssa form. 
@@ -528,6 +535,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -565,21 +573,21 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -588,7 +596,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = 
GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -630,9 +638,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); // Element 0 passes transparently through the body. auto inner_builder = HloComputation::Builder("inner_body"); @@ -647,7 +655,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { inner_builder.AddInstruction( HloInstruction::CreateTuple({inner_element_0, add})); HloComputation* inner_body = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); // Element 1 passes transparently through the body. auto outer_builder = HloComputation::Builder("outer_body"); @@ -664,18 +672,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( tuple_shape, condition, inner_body, outer_tuple)); HloComputation* outer_body = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -751,26 +759,26 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_1, body_element_0})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -817,15 +825,15 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { // Test a kSelect of an array value. 
auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -841,15 +849,15 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { // instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -868,7 +876,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { auto select1234 = 
builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, select12, select34)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -899,31 +907,33 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { analysis.GetValueDefinedAt(constant4))); EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{tuple1, 0, {}}, HloUse{select11, 1, {0}}, - HloUse{select11, 2, {0}}, HloUse{select12, 1, {0}}, - HloUse{select1234, 1, {0}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}}, - HloUse{select1234, 1, {0}})); + analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}}, + HloUse{select12, 1, {}})); + + // The two constant values just pass through the Selects and are not + // used. They are live out however. + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { // Test kSelect of a nested tuple. 
auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto constant5 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant3})); auto tuple1 = builder.AddInstruction( @@ -935,7 +945,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -993,24 +1003,24 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); 
cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -1024,7 +1034,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1062,11 +1072,11 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { // Test the bitcast_defines_value flag to the dataflow analysis. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); { @@ -1102,7 +1112,7 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1126,6 +1136,352 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { + // A simple chain of elementwise operations. No values should interfere. + // + // param --> negate -> exp -> log + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp)); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + // No values should interfere. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp)); + + // A value should interfere with itself. + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp)); +} + +TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { + // Two entry params, which interfere with each other. + // + // param0 --> negate ---------------\ + // param1 --> exp --> add + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, vector_shape_, "param1")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + auto entry = module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param0, negate, param1, exp, add}}); + SequentialHloOrdering ordering(module_.get(), sequence); + + // Entry parameters interfere as if they are defined simultaneously at + // the very beginning.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add)); + + // Negate and exp still interfere. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + + // But {negate, add} and {exp, add} don't interfere. + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { + // Similar to MultipleEntryParameters_Sequential, but the parameter is of + // while body computation. Body computation in the sequential order: + // + // %constant = Constant(...) + // %exp = Exp(%constant) + // %param = Param(0) + // %add = Add(%param, %exp) ;; Root of body + // %dead_constant = Constant(...) + // %dead_negate = Negate(%dead_constant) + // + // %constant and its only use %exp are ordered before 'param'. However, the + // %constant and %param values still interfere because the parameter is + // considered live into the while body. + // + // Similarly, %dead_constant and %dead_negate are ordered after the root of + // the body computation %add. However, %add is liveout of the computation so + // %dead_constant and %add interfere. 
+ auto body_builder = HloComputation::Builder(TestName()); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); + auto constant = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto exp = body_builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, exp, body_param)); + auto dead_constant = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, dead_constant)); + HloComputation* body = module_->AddEmbeddedComputation( + body_builder.Build(/*root_instruction=*/add)); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param)); + + auto entry = module_->AddEntryComputation(builder.Build()); + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param, xla_while}}); + sequence.insert({condition, {cond_param, cond_constant}}); + // Construct the order such that 'constant' and its use 'exp' are before + // body_param. 
+ sequence.insert({body, {constant, exp, body_param, add}}); + + SequentialHloOrdering ordering(module_.get(), sequence); + + // 'add' is the body root even though later instructions follow in the order + // like 'dead_negate'. Only 'add' should be live out of the computation. + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_FALSE( + analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); + + // 'add' is live out of the body and will interfere with later instructions + // such as 'dead_constant' and 'dead_negate'. + EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate)); + + // The remaining checks test phi values defined by body and condition + // parameters which only occur in the SSA form of the analysis. + if (ssa_form) { + // Though the ordering suggests 'constant' and 'param' should not interfere, + // 'param' is live in and thus interferes with any earlier instruction of + // the computation in the order (e.g. 'constant'). + EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); + + // The following values end up in the same buffer: + // (1) the init value: 'param' + // (2) the body parameter: 'body_param' + // (3) the condition parameter: 'cond_param' + // (4) the root value of the while body: 'add' + // (5) the while value: 'xla_while' + // None should interfere.
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while)); + } +} + +TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) { + // A chain of operations with two elementwise and one non-elementwise. The + // elementwise op should not interfere with its operand, while the + // non-elementwise op should interfere. Entry params always interfere. + // + // param --> exp -> negate -> reverse + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp)); + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(vector_shape_, negate, {0})); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse)); + + // Negate is elementwise, so doesn't interfere with its operand. + // Reverse is non-elementwise, so does interfere with its operand. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse)); +} + +TEST_P(HloDataflowAnalysisTest, OverlappedValues) { + // Verify simultaneously live values interfere (exp and negate). + // + // param --> negate -> add + // \---> exp -----/ + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { + // Identical to the test OverlappedValues but using a sequential ordering of + // HLO instructions. + // + // param --> negate -> add + // \---> exp -----/ + // + // Sequential order: + // param, negate, exp, add + // + // Liveness is identical to the DependencyHloOrdering.
+ auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + auto entry = module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + SequentialHloOrdering::HloModuleSequence sequence; + std::vector order = {param, negate, exp, add}; + sequence.emplace(entry, order); + + SequentialHloOrdering ordering(module_.get(), sequence); + + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { + // Test MayInterfere() for embedded computation, specifically the interference + // of values in different computations. 
+ // + // embedded_computation: + // %embedded_param = Param(0) + // %embedded_log = Log(%embedded_param) + // + // entry computation: + // %param = Param(0) + // %negate = Negate(%param) + // %exp = Exp(%param) + // %call = Call(embedded_computation, {%exp}) + // %add = Add(%negate, %call) + // + // Note %negate is live across the call and should interfere with all values + // in the embedded computation. + auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); + auto embedded_param = embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "embedded_param")); + auto embedded_log = + embedded_builder.AddInstruction(HloInstruction::CreateUnary( + vector_shape_, HloOpcode::kLog, embedded_param)); + auto embedded_computation = + module_->AddEmbeddedComputation(embedded_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation)); + builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, call)); + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + // Exp's only use is the call so it should not interfere with values inside the + // embedded computation. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log)); + + // Negate is live across the call and should interfere with values in the + // embedded computation + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 3755b9e4c005c5e50b149d8dc8c51363eb111868..5b2c57da4ff3a1f887f777c3304893d950b3d3a9 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -38,6 +38,9 @@ StatusOr HloDCE::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 10cd7ca7c0990ab553c865da01b00475382316e2..704b8dfca700f7c4a00689593aea9743de1f817c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -45,9 +45,9 @@ TEST_F(HloDceTest, NoDeadCode) { // Verify that no dead code is removed from a computation with no dead code. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -98,9 +98,9 @@ TEST_F(HloDceTest, ControlDependencies) { // Verify that instructions with control dependencies are not removed. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); // Create two dead instructions: a negate and an add. auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3e7f5b1f3d97ace48fbc22b224667acebcc52093..a0c5cbe916050a8aa7849c3e37daad70bc8d6190 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -31,11 +31,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -89,11 +91,11 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = LiteralUtil::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index)); + auto result = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); @@ -117,12 +119,11 @@ StatusOr> ElementWiseUnaryOpImpl( ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return unary_op( - LiteralUtil::Get(operand_literal, multi_index)); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); } @@ -168,6 +169,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleAbs(abs, operand); }; + Status HandleBroadcast(HloInstruction* broadcast) 
override { + parent_->evaluated_[broadcast] = + Literal::CreateFromShape(broadcast->shape()); + auto output = parent_->evaluated_[broadcast].get(); + auto operand_to_broadcast = + parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); + std::vector broadcast_indices( + ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); + return output->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + }); + } + Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { @@ -176,7 +194,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override { + Status HandleCopy(HloInstruction* copy) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { return elem_operand; @@ -184,42 +202,19 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - template - std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { - DCHECK_EQ(src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative::type>( - src_literal); - } - - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override { - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); - - switch (operand->shape().element_type()) { -#define CONVERT_IF_TYPES_MATCH(src_type) \ - case (src_type): \ - parent_->evaluated_[convert] = LiteralUtil::Convert< \ - typename primitive_util::PrimitiveTypeToNative::type, \ - ReturnT>(operand_literal); \ - break; - 
CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "unimplemented operand type for HandleCovert: " - << PrimitiveType_Name(operand->shape().element_type()); + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); } - return Status::OK(); } @@ -322,8 +317,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMaximum(HloInstruction* maximum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { @@ -332,8 +326,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMinimum(HloInstruction* minimum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { @@ -409,6 +402,258 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + Status HandleConvolution(HloInstruction* conv, 
HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override { + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), conv->shape())); + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs->shape())); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs->shape())); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 1); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(lhs->shape())); + CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(rhs->shape())); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), window, dnums)); + CHECK(ShapeUtil::Compatible(conv->shape(), inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(conv->shape()) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + // Dimension number applicable for both input (lhs), and output. + const int64 batch_dim = dnums.batch_dimension(); + const int64 z_dim = dnums.feature_dimension(); + // Dimension number applicable for kernel (rhs). 
+ const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs->shape(), z_dim); + + std::vector window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back( + ShapeUtil::GetDimension(rhs->shape(), i)); + } + + const Shape& window_shape = ShapeUtil::MakeShape( + rhs->shape().element_type(), window_dimension_sizes); + + auto result = Literal::CreateFromShape(conv->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + ReturnT result_val = static_cast(0); + + std::vector lhs_index(lhs_rank, 0); + std::vector rhs_index(rhs_rank, 0); + + lhs_index[batch_dim] = out_index[batch_dim]; + rhs_index[kernel_output_z_dim] = out_index[z_dim]; + + std::vector rhs_spatial_index( + dnums.kernel_spatial_dimensions_size(), 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + lhs_index[z_dim] = iz; + rhs_index[kernel_input_z_dim] = iz; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 spatial_dim = dnums.spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const int64 undilated_index = + out_index[spatial_dim] * window.dimensions(ki).stride() - + window.dimensions(ki).padding_low() + + rhs_spatial_index[ki] * + window.dimensions(ki).window_dilation(); + // Skip if the lhs (input) index is to be dilated. + if (undilated_index % window.dimensions(ki).base_dilation() != + 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. + lhs_index[spatial_dim] = + undilated_index / window.dimensions(ki).base_dilation(); + + // Skip if input index is not in bound. 
+ if (!(lhs_index[spatial_dim] >= 0 && + lhs_index[spatial_dim] < + lhs->shape().dimensions(spatial_dim))) { + goto cnt; + } + + rhs_index[dnums.kernel_spatial_dimensions(ki)] = + rhs_spatial_index[ki]; + } + + result_val += lhs_literal.Get(lhs_index) * + rhs_literal.Get(rhs_index); + } + cnt:; + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return result_val; + })); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + }; + + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override { + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + // Dot only supports operands of rank 1 and 2. + const auto dot_rank = ShapeUtil::Rank(dot->shape()); + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + CHECK(lhs_rank > 0 && lhs_rank <= 2); + CHECK(rhs_rank > 0 && rhs_rank <= 2); + CHECK_EQ(dot_rank, lhs_rank + rhs_rank - 2); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // Check contracted dimensions are the same. + // + // Determine the index of the contracted dimensions for input tensors. + // dimensions -1 of lhs and dimension 0 of rhs are contracted. 
+ const int64 lhs_contracted_dimension = + ShapeUtil::GetDimensionNumber(lhs->shape(), -1); + const int64 rhs_contracted_dimension = 0; + CHECK_EQ(lhs->shape().dimensions(lhs_contracted_dimension), + rhs->shape().dimensions(rhs_contracted_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracted_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracted_dimension); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracted_dimension); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = Literal::CreateFromShape(dot->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + ReturnT result_val = static_cast(0); + + std::vector lhs_index(lhs_rank, 0); + std::vector rhs_index(rhs_rank, 0); + // Set index for non-contracted dimension for lhs and rhs. + if (lhs_rank > 1) { + lhs_index[0] = multi_index[0]; + } + if (rhs_rank > 1) { + rhs_index[1] = multi_index[multi_index.size() - 1]; + } + + // Accumulates resulting product along the contracted dimension. + for (int64 i = 0; i < contracted_dimension_size; ++i) { + lhs_index[lhs_contracted_dimension] = i; + rhs_index[rhs_contracted_dimension] = i; + + result_val += lhs_literal.Get(lhs_index) * + rhs_literal.Get(rhs_index); + } + + return result_val; + })); + + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + }; + + Status HandlePad(HloInstruction* pad) override { + CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + // Padding value must be scalar. 
+ CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); + CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + pad->padding_config().dimensions_size()); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferPadShape( + /*operand_shape=*/pad->operand(0)->shape(), + /*padding_value_shape=*/pad->operand(1)->shape(), + /*padding_config=*/pad->padding_config())); + CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + // Create new HLO of padded shape with padding value. + ReturnT scalar = + parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); + auto result = Literal::CreateFromShape(pad->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&scalar](tensorflow::gtl::ArraySlice multi_index) { + return scalar; + })); + + auto evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); + + std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), + 0); + std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + + // Loop through each element of the operand, assign them to the + // corresponding index of the resulting padded literal. + const PaddingConfig& pad_config = pad->padding_config(); + + auto func = [&](const std::vector& input_index) { + for (auto i = 0; i < input_index.size(); ++i) { + // Interior padding occurs logically before edge padding, so in the case + // of negative edge padding elements are removed from the + // interior-padded operand. + target_index[i] = + pad_config.dimensions(i).edge_padding_low() + + input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); + + // Account for negative low and high padding: skip assignment if the + // any target index is out of range. 
+ if (!(target_index[i] >= 0 && + target_index[i] < pad->shape().dimensions(i))) { + return true; + } + } + result->Set(target_index, + evaluated_operand.Get(input_index)); + return true; + }; + + std::vector zero_base(evaluated_operand.shape().dimensions_size(), + 0); + std::vector step(evaluated_operand.shape().dimensions_size(), 1); + + ShapeUtil::ForEachIndex( + evaluated_operand.shape(), zero_base, + AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); + + parent_->evaluated_[pad] = std::move(result); + return Status::OK(); + }; + Status Preprocess(HloInstruction* hlo) override { VLOG(2) << hlo->ToString(); return Status::OK(); @@ -446,12 +691,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return binary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index)); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return binary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); } @@ -483,14 +728,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return ternary_op( - LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index), - LiteralUtil::Get(ehs_literal, multi_index)); + 
TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); })); return std::move(result); @@ -552,7 +796,7 @@ StatusOr> HloEvaluator::Evaluate( if (operand->opcode() == HloOpcode::kParameter) { const Literal* input_literal = arg_literals_[operand->parameter_number()]; VLOG(2) << "Parameter operand evaluated to: " - << LiteralUtil::ToString(*input_literal); + << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); evaluated_[operand] = MakeUnique(*input_literal); @@ -589,8 +833,7 @@ std::unique_ptr HloEvaluator::TryEvaluate( Status HloEvaluator::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; - VLOG(2) << "Parameter evaluated to: " - << LiteralUtil::ToString(*input_literal); + VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); evaluated_[parameter] = MakeUnique(*input_literal); @@ -606,14 +849,14 @@ Status HloEvaluator::HandleConstant(HloInstruction* constant, Status HloEvaluator::HandleReshape(HloInstruction* reshape) { TF_ASSIGN_OR_RETURN( evaluated_[reshape], - LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), - AsInt64Slice(reshape->shape().dimensions()))); + GetEvaluatedLiteralFor(reshape->operand(0)) + .Reshape(AsInt64Slice(reshape->shape().dimensions()))); return Status::OK(); } Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { - evaluated_[transpose] = LiteralUtil::Transpose( - GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0)) + .Transpose(transpose->dimensions()); return Status::OK(); } @@ -641,16 +884,16 @@ Status 
HloEvaluator::HandleConcatenate( ShapeUtil::GetDimension(operand_shape, concat_dim); } - auto result_literal = LiteralUtil::CreateFromDimensions( + auto result_literal = Literal::CreateFromDimensions( reference_shape.element_type(), concat_dimensions); DimensionVector source_indices(rank, 0); DimensionVector dest_indices(concat_dimensions.size(), 0); for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), - dest_indices, AsInt64Slice(operand_shape.dimensions()))); + TF_RETURN_IF_ERROR(result_literal->Copy( + GetEvaluatedLiteralFor(operand), source_indices, dest_indices, + AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += ShapeUtil::GetDimension(operand_shape, concat_dim); } @@ -775,14 +1018,14 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, Status HloEvaluator::HandleSlice(HloInstruction* slice, HloInstruction* operand) { const Shape& shape = slice->shape(); - auto literal = LiteralUtil::CreateFromDimensions( + auto literal = Literal::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); DimensionVector dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(), - dest_indices, AsInt64Slice(shape.dimensions()))); + TF_RETURN_IF_ERROR(literal->Copy(GetEvaluatedLiteralFor(operand), + slice->slice_starts(), dest_indices, + AsInt64Slice(shape.dimensions()))); evaluated_[slice] = std::move(literal); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 91fd56f54c592b8bbe68f6b38e761e1f10a20c8b..976a2325ea970f570748a6872d7bf2459f8ffa4a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -92,7 +92,8 @@ class 
HloEvaluator : public DfsHloVisitorWithDefault { return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); } - // Operations that are type-agnostic. + // Operations that are type-agnostic or always return a specific type, such as + // HandleIsFinite where boolean is always returned. // Status HandleParameter(HloInstruction* parameter) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index b26ece28b756097b06b4a04d4873775e13760014..7269fbeffc51c39af43f2cfd8e5468da54f12855 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -14,27 +14,33 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include #include #include +#include #include #include +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class HloEvaluatorTest : public ::testing::Test { +class HloEvaluatorTest : public HloTestBase { protected: HloEvaluatorTest() { evaluator_ = MakeUnique(); } @@ -44,9 +50,9 @@ class HloEvaluatorTest : public ::testing::Test { // Verifies that 
HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_F(HloEvaluatorTest, DoesClamp) { - auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = low->shape(); auto c1 = HloInstruction::CreateConstant(std::move(low)); @@ -58,17 +64,17 @@ TEST_F(HloEvaluatorTest, DoesClamp) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); + auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. 
TEST_F(HloEvaluatorTest, DoesSelect) { - auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); - auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto pred = Literal::CreateR2({{true, false}, {false, true}}); + auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = on_true->shape(); auto c1 = HloInstruction::CreateConstant(std::move(pred)); @@ -80,16 +86,16 @@ TEST_F(HloEvaluatorTest, DoesSelect) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); + auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. TEST_F(HloEvaluatorTest, DoesAdd) { - auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(lhs)); @@ -100,16 +106,16 @@ TEST_F(HloEvaluatorTest, DoesAdd) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); + auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. 
TEST_F(HloEvaluatorTest, DoesDivide) { - auto lhs_s64 = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); auto c1_s64 = HloInstruction::CreateConstant(std::move(lhs_s64)); @@ -120,12 +126,12 @@ TEST_F(HloEvaluatorTest, DoesDivide) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); + auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); - auto lhs_f64 = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); auto c1_f64 = HloInstruction::CreateConstant(std::move(lhs_f64)); @@ -135,16 +141,15 @@ TEST_F(HloEvaluatorTest, DoesDivide) { result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - expected = - LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + expected = Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. 
TEST_F(HloEvaluatorTest, DoesAbs) { - auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(operand)); auto instruction = @@ -153,42 +158,40 @@ TEST_F(HloEvaluatorTest, DoesAbs) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); + auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); // For R0 literal. const Shape& r0 = ShapeUtil::MakeShape(F32, {}); - operand = LiteralUtil::CreateR0(-1.0f); + operand = Literal::CreateR0(-1.0f); c1 = HloInstruction::CreateConstant(std::move(operand)); instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get()); result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); - expected = LiteralUtil::CreateR0(1.0f); + expected = Literal::CreateR0(1.0f); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); // For R1 literal with dimension of size 0. Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); - operand = LiteralUtil::CreateR1({}); + operand = Literal::CreateR1({}); c1 = HloInstruction::CreateConstant(std::move(operand)); instruction = HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get()); result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); - expected = LiteralUtil::CreateR1({}); + expected = Literal::CreateR1({}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // namespace // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. 
-TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { - HloComputation::Builder builder( - ::testing::UnitTest::GetInstance()->current_test_info()->name()); - - auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); - auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { + HloComputation::Builder builder(TestName()); + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -206,21 +209,19 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { std::unique_ptr result = evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); + auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies Reshape operation is correctly evaluated. 
TEST_F(HloEvaluatorTest, DoesReshape) { - HloComputation::Builder builder( - ::testing::UnitTest::GetInstance()->current_test_info()->name()); - + HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; - TF_ASSIGN_OR_ASSERT_OK(auto literal, - LiteralTestUtil::CreateRandomLiteral( - ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = LiteralUtil::CloneToUnique(*literal); + TF_ASSERT_OK_AND_ASSIGN(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -233,13 +234,717 @@ TEST_F(HloEvaluatorTest, DoesReshape) { evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - LiteralUtil::EachCell( - *result, [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + result->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_TRUE(value == - LiteralUtil::Get(*literal_clone, rindexes)); + EXPECT_TRUE(value == literal_clone->Get(rindexes)); }); } +// Verifies Broadcast operation is correctly evaluated. 
+TEST_F(HloEvaluatorTest, DoesBroadcast) { + HloComputation::Builder builder(TestName()); + auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto output_literal = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}}); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + + builder.AddInstruction(HloInstruction::CreateBroadcast( + output_literal->shape(), literal_instruction, {1, 2})); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *output_literal); +} + +TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { + HloComputation::Builder builder(TestName()); + + auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto expected = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), + expected->shape())); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + builder.AddInstruction( + HloInstruction::CreateConvert(expected->shape(), constant)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { + HloComputation::Builder builder(TestName()); + + auto input_literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); + auto expected = Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), + expected->shape())); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + builder.AddInstruction( + HloInstruction::CreateConvert(expected->shape(), 
constant)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +PaddingConfig CreatePaddingConfig( + std::initializer_list> padding_dimensions) { + PaddingConfig padding_config; + + for (auto& paddings_per_dim : padding_dimensions) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(paddings_per_dim[0]); + dimension->set_edge_padding_high(paddings_per_dim[1]); + dimension->set_interior_padding(paddings_per_dim[2]); + } + return padding_config; +} + +TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { + auto operand = Literal::CreateR2({{}, {}}); + auto operand_instruction = HloInstruction::CreateConstant(std::move(operand)); + + constexpr int32 kPadValue = 10; + auto pad_value = Literal::CreateR0(kPadValue); + auto padding_value_instruction = + HloInstruction::CreateConstant(std::move(pad_value)); + + auto padding_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}}); + Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); + auto pad_instruction = HloInstruction::CreatePad( + shape, operand_instruction.get(), padding_value_instruction.get(), + padding_config); + + auto result = evaluator_->Evaluate(pad_instruction.get()).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2( + {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { + HloComputation::Builder b(TestName()); + + Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); + auto input = Literal::CreateR4FromArray4D(input_array); + HloInstruction* input_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + constexpr float kPadValue = 1.5; + auto pad_value = Literal::CreateR0(kPadValue); + HloInstruction* pad_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); + + Shape 
shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1}); + auto r4_padding_on_dim0_dim1 = + CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); + b.AddInstruction(HloInstruction::CreatePad( + shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected_array = MakeUnique>(8, 5, 1, 1); + expected_array->Fill(kPadValue); + (*expected_array)(1, 0, 0, 0) = 1.0f; + (*expected_array)(1, 2, 0, 0) = 2.0f; + (*expected_array)(4, 0, 0, 0) = 3.0f; + (*expected_array)(4, 2, 0, 0) = 4.0f; + (*expected_array)(7, 0, 0, 0) = 5.0f; + (*expected_array)(7, 2, 0, 0) = 6.0f; + + auto expected = Literal::CreateR4FromArray4D(*expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, NegativePadding2D) { + HloComputation::Builder b(TestName()); + + // input_array: + // f32[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto input_array = MakeUnique>(4, 3); + input_array->FillUnique(1.0f); + auto input = Literal::CreateR2FromArray2D(*input_array); + HloInstruction* input_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + + auto pad_value_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + + auto r2_padding_on_dim0_dim1 = + CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 5}); + b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction, + pad_value_instruction, + r2_padding_on_dim0_dim1)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } + auto expected_array = MakeUnique>(1, 5); + (*expected_array)(0, 0) = 7.0f; + (*expected_array)(0, 1) = 2.718f; + (*expected_array)(0, 2) = 2.718f; + (*expected_array)(0, 3) = 2.718f; + 
(*expected_array)(0, 4) = 2.718f; + auto expected = Literal::CreateR2FromArray2D(*expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { + HloComputation::Builder b(TestName()); + + // f32[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto input_array = MakeUnique>(4, 3); + input_array->FillUnique(1.0f); + auto input = Literal::CreateR2FromArray2D(*input_array); + HloInstruction* input_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + + auto pad_value_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + + PaddingConfig padding_config = MakeNoPaddingConfig(2); + + // Negative padding that results in zero dimensions. + auto r2_padding_on_dim0_dim1 = + CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}}); + + Shape shape = ShapeUtil::MakeShape(F32, {0, 9}); + b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction, + pad_value_instruction, + r2_padding_on_dim0_dim1)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected_array = MakeUnique>(0, 9); + auto expected = Literal::CreateR2FromArray2D(*expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DotRank2AndRank1) { + HloComputation::Builder b(TestName()); + + // lhs: + // f32[4,1] { + // { 1 }, + // { 2 }, + // { 3 }, + // { 4 }, + // } + auto lhs_array = MakeUnique>(4, 1); + lhs_array->FillUnique(1.0f); + auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + // rhs: + // f32[2] { 1, 2 }, + auto rhs_literal = Literal::CreateR2({{1, 2}}); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = 
ShapeUtil::MakeShape(F32, {4, 2}); + b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + // clang-format off + auto expected_array = Array2D({ + {1.f, 2.f}, + {2.f, 4.f}, + {3.f, 6.f}, + {4.f, 8.f}, + }); + // clang-format on + auto expected = Literal::CreateR2FromArray2D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DotRank1AndRank2) { + HloComputation::Builder b(TestName()); + + // lhs: + // f32[3] + // { 1, 2, 3 }, + auto lhs_literal = Literal::CreateR1({1, 2, 3}); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + // rhs: + // f32[3,2] { + // { 1, 2 }, + // { 3, 4 }, + // { 5, 6 }, + // } + auto rhs_array = MakeUnique>(3, 2); + rhs_array->FillUnique(1.0f); + auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {2}); + b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR1({22.f, 28.f}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DotRank2AndRank2) { + HloComputation::Builder b(TestName()); + + // lhs: + // f32[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto lhs_array = MakeUnique>(4, 3); + lhs_array->FillUnique(1.0f); + auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + // rhs: + // f32[3,2] { + // { 1, 2 }, + // { 3, 4 }, + // { 5, 6 
}, + // } + auto rhs_array = MakeUnique>(3, 2); + rhs_array->FillUnique(1.0f); + auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); + b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected_array = Array2D({ + {22.f, 28.f}, {58.f, 76.f}, {94.f, 124.f}, {130.f, 172.f}, + }); + auto expected = Literal::CreateR2FromArray2D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, SimpleConv1D) { + HloComputation::Builder b(TestName()); + + Array3D lhs_array = {{{1, 2, 3}}}; + auto lhs_literal = Literal::CreateR3FromArray3D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array3D rhs_array = {{{3.f, 4.f}}}; + auto rhs_literal = Literal::CreateR3FromArray3D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(0); + dnums.set_feature_dimension(1); + dnums.add_spatial_dimensions(2); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), 
{}).ConsumeValueOrDie(); + + Array3D expected_array = {{{11.f, 18.f, 9.f}}}; + auto expected = Literal::CreateR3FromArray3D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { + HloComputation::Builder b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + // clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 2); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 4, 4); + // clang-format off + expected_array.FillWithYX(Array2D({ + {100, 126, 152, 76}, + {204, 230, 256, 124}, + {308, 334, 360, 172}, + {149, 160, 171, 80}, + })); + // clang-format on + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, 
Conv2DGeneralDimensions) { + HloComputation::Builder b(TestName()); + + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3] + Array4D weight({{ + {{1, 7, 13}, + {4, 10, 16}}, + {{2, 8, 14}, + {5, 11, 17}}, + {{3, 9, 15}, + {6, 12, 18}} + }}); + // clang-format on + + auto lhs_literal = Literal::CreateR4FromArray4D(input); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_literal = Literal::CreateR4FromArray4D(weight); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(3); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(2); + dnums.set_feature_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(3); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(1); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + // clang-format off + // Result dimensions: [feature=1, height=1, batch=1, width=2] + Array4D expected_array({{{{2514, 2685}}}}); + // clang-format on + auto expected = Literal::CreateR4FromArray4D(expected_array); + + 
LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { + HloComputation::Builder b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + // clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 2); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(2); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 7, 7); + expected_array.FillWithYX(Array2D({ + {5, 12, 10, 18, 15, 24, 20}, + {35, 48, 42, 56, 49, 64, 56}, + {25, 36, 30, 42, 35, 48, 40}, + {63, 80, 70, 88, 77, 96, 84}, + {45, 60, 50, 66, 55, 72, 60}, + {91, 112, 98, 120, 105, 128, 112}, + {65, 84, 70, 90, 75, 96, 80}, + })); + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { + HloComputation::Builder 
b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + // clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 2); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(1); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(2); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 8, 8); + expected_array.FillWithYX(Array2D({ + {8, 7, 16, 14, 24, 21, 32, 28}, + {6, 5, 12, 10, 18, 15, 24, 20}, + {40, 35, 48, 42, 56, 49, 64, 56}, + {30, 25, 36, 30, 42, 35, 48, 40}, + {72, 63, 80, 70, 88, 77, 96, 84}, + {54, 45, 60, 50, 66, 55, 72, 60}, + {104, 91, 112, 98, 120, 105, 128, 112}, + {78, 65, 84, 70, 90, 75, 96, 80}, + })); + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, + DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { + HloComputation::Builder b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + 
// clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 3); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6, 7}, + {8, 9, 10}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(2); + dim.set_padding_high(2); + dim.set_window_dilation(2); + dim.set_base_dilation(2); + *window.add_dimensions() = dim; + dim.set_size(3); + dim.set_stride(3); + dim.set_padding_low(2); + dim.set_padding_high(-1); + dim.set_window_dilation(1); + dim.set_base_dilation(3); + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 9, 3); + expected_array.FillWithYX(Array2D({ + {10, 20, 30}, + {0, 0, 0}, + {57, 74, 91}, + {0, 0, 0}, + {125, 142, 159}, + {0, 0, 0}, + {193, 210, 227}, + {0, 0, 0}, + {91, 98, 105}, + })); + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index 
9e25f1aceb1595b89aee601b294792e9e801c6f3..7a83a92404e3cd88f3075322111880cc95637c23 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -19,14 +19,11 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -55,96 +52,19 @@ string HloExecutionProfile::ToString( return ""; } - using Item = std::pair; - std::vector items; - for (Item item : hlo_to_cycles_taken_) { - // Only include the HLOs which are part of the desired computation. 
- if (item.first->parent() == &computation) { - items.push_back(item); - } - } - auto custom_less = [](const Item& lhs, const Item& rhs) { - return lhs.second > rhs.second; - }; - std::sort(items.begin(), items.end(), custom_less); - string result; - const int64 total_cycles = total_cycles_executed(computation); - double clock_rate_ghz = device_description.clock_rate_ghz(); - CHECK_GE(clock_rate_ghz, 1e-9); - - const auto cycles_to_microseconds = [&](double cycles) { - return cycles / clock_rate_ghz / 1000.0; - }; - - auto append_item = [&](int64 cycles, int64 flops, int64 bytes_accessed, - const string& name) { - double nsecs = cycles / clock_rate_ghz; - string bytes_per_sec; - string bytes_per_cycle; - if (cycles <= 0 || bytes_accessed < 0) { - bytes_per_sec = ""; - bytes_per_cycle = ""; - } else { - bytes_per_sec = tensorflow::strings::HumanReadableNumBytes( - bytes_accessed / (nsecs / 1e9)); - bytes_per_cycle = - tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles); - } - - double cycles_percent = 0; - if (total_cycles > 0) { - cycles_percent = cycles / static_cast(total_cycles) * 100; - } - - tensorflow::strings::StrAppend( - &result, - tensorflow::strings::Printf( - "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %12s/s " - ":: " - "%12s/cycle :: " - "%s", - cycles, cycles_percent, cycles_to_microseconds(cycles), - flops <= 0 ? 
"" : HumanReadableNumFlops(flops, nsecs).c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str())); - }; - tensorflow::strings::StrAppend( - &result, tensorflow::strings::Printf( - "HLO execution profile for %s: (%s @ f_nom)\n\t", - computation.name().c_str(), - tensorflow::strings::HumanReadableElapsedTime( - total_cycles / clock_rate_ghz / 1e9) - .c_str())); - - append_item(total_cycles, -1, -1, "[total]"); - for (const auto& item : items) { + HumanReadableProfileBuilder builder(computation.name(), + total_cycles_executed(computation), + device_description.clock_rate_ghz()); + for (const auto& item : hlo_to_cycles_taken_) { const HloInstruction* hlo = item.first; - tensorflow::strings::StrAppend(&result, "\n\t"); - const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo); - const int64 bytes_accessed = - (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo); - const string display = (hlo == nullptr) ? "" : hlo->ToString(); - append_item(item.second, flops, bytes_accessed, display); - } + int64 cycles = item.second; - if (total_cycles <= 0) { - result += "****** 0 total cycles ******\n"; - } else { - MetricTableReport table; - table.SetMetricName("microseconds"); - table.SetEntryName("ops"); - table.SetShowCategoryTable(); - for (const auto& item : items) { - MetricTableReport::Entry entry; - entry.text = item.first->ToString(); - entry.short_text = item.first->ToString(/*compact_operands=*/true); - entry.category_text = item.first->ToCategory(); - entry.metric = cycles_to_microseconds(item.second); - table.AddEntry(std::move(entry)); - } - result += table.MakeReport(cycles_to_microseconds(total_cycles)); + builder.AddOp(/*op_name=*/hlo->ToString(), + /*short_name=*/hlo->ToString(/*compact_operands=*/true), + hlo->ToCategory(), cycles, cost_analysis.flop_count(*hlo), + cost_analysis.bytes_accessed(*hlo)); } - - return result; + return builder.ToString(); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index eb2e5dfb37f33fd138e20ee930a2242cb1db89ea..c6c06658316e28d6c40b8d6ce371e3accdd42fcb 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -16,10 +16,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include +#include +#include +#include +#include +#include #include +#include +#include +#include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" @@ -27,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -34,20 +42,100 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; +using ::tensorflow::gtl::nullopt; +using ::tensorflow::gtl::optional; using ::tensorflow::io::JoinPath; -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::StringReplace; +using ::tensorflow::WriteStringToFile; namespace xla { namespace hlo_graph_dumper { namespace { +// Helpers for Printf and Appendf. 
+template +struct PrintfConvert { + const T& operator()(const T& t) const { return t; } +}; +template <> +struct PrintfConvert { + const char* operator()(const string& s) const { return s.c_str(); } +}; + +// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() +// on strings. +template +string Printf(const char* fmt, const Ts&... ts) { + return tensorflow::strings::Printf(fmt, PrintfConvert()(ts)...); +} +template +void Appendf(string* s, const char* fmt, const Ts&... ts) { + tensorflow::strings::Appendf(s, fmt, PrintfConvert()(ts)...); +} + +// Used to indicate how we should treat a given HLOInstruction in the graph. +// should we treat it like normal, hide it, and so on? +enum NodeFilterResult { + kNormalNode, + kHideNode, + // Make the node easy to find in the final graph. + kHighlightNode, + // "Gray out" the node to indicate that some of its operands have been + // omitted. + kSomeOperandsOmitted, + // Style the node the same as kSomeOperandsOmitted, but also don't connect it + // to its operands, even if they're present in the graph. + kOmitNodeOperands, + // Same style as kSomeOperandsOmitted, but used to indicate that some of the + // node's *users* have been omitted. + kSomeUsersOmitted, +}; + +// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult. +// It lets callers tell the graph-drawing routines which nodes they want to be +// shown, hidden, or highlighted. 
+class NodeFilter { + public: + NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {} + + explicit NodeFilter( + std::function filter) + : filter_(std::move(filter)) {} + + bool Show(const HloInstruction* instr) const { + return filter_(instr) != kHideNode; + } + bool Highlight(const HloInstruction* instr) const { + return filter_(instr) == kHighlightNode; + } + bool OmitOperands(const HloInstruction* instr) const { + return filter_(instr) == kOmitNodeOperands; + } + bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const { + auto result = filter_(instr); + return result == kOmitNodeOperands || result == kSomeOperandsOmitted; + } + bool Deemphasized(const HloInstruction* instr) const { + auto result = filter_(instr); + return result == kOmitNodeOperands || result == kSomeOperandsOmitted || + result == kSomeUsersOmitted; + } + + bool ShowFusionSubcomputation(const HloInstruction* instr) const { + CHECK_EQ(instr->opcode(), HloOpcode::kFusion); + return Show(instr) && !SomeOrAllOperandsOmitted(instr); + } + + private: + std::function filter_; +}; + // Node color schemes, used by NodeColorAttributes. enum ColorScheme { kBlue, @@ -62,420 +150,780 @@ enum ColorScheme { kRed, kWhite, kYellow, + + // Causes the node's border to be a dashed line, and its content to be gray + // text on a white background, suggesting that this is an "unimportant" node. + kDashedBorder, }; // Given a ColorScheme, returns an attribute string for a node of that color. -// Sets the node's fill, stroke, and text colors. +// Sets the node's style and fill/stroke/text colors. // // Colors are from https://material.io/color. 
string NodeColorAttributes(ColorScheme color) { using std::make_tuple; - const char *fill_color, *stroke_color, *font_color; - std::tie(fill_color, stroke_color, font_color) = - [color]() -> std::tuple { + const char *style, *fill_color, *stroke_color, *font_color; + std::tie(style, fill_color, stroke_color, font_color) = [color] { switch (color) { case kBlue: - return make_tuple("#bbdefb", "#8aacc8", "black"); + return make_tuple("filled", "#bbdefb", "#8aacc8", "black"); case kBrown: - return make_tuple("#bcaaa4", "#8c7b75", "black"); + return make_tuple("filled", "#bcaaa4", "#8c7b75", "black"); case kDarkBlue: - return make_tuple("#1565c0", "#003c8f", "white"); + return make_tuple("filled", "#1565c0", "#003c8f", "white"); case kDarkGreen: - return make_tuple("#2e7d32", "#005005", "white"); + return make_tuple("filled", "#2e7d32", "#005005", "white"); case kDarkRed: - return make_tuple("#b71c1c", "#7f0000", "white"); + return make_tuple("filled", "#b71c1c", "#7f0000", "white"); case kGray: - return make_tuple("#cfd8dc", "#9ea7aa", "black"); + return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black"); case kGreen: - return make_tuple("#c8e6c9", "#97b498", "black"); + return make_tuple("filled", "#c8e6c9", "#97b498", "black"); case kOrange: - return make_tuple("#ffe0b2", "#cbae82", "black"); + return make_tuple("filled", "#ffe0b2", "#cbae82", "black"); case kPurple: - return make_tuple("#e1bee7", "#af8eb5", "black"); + return make_tuple("filled", "#e1bee7", "#af8eb5", "black"); case kRed: - return make_tuple("#ffcdd2", "#cb9ca1", "black"); + return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black"); case kWhite: - return make_tuple("white", "black", "black"); + return make_tuple("filled", "white", "black", "black"); case kYellow: - return make_tuple("#fff9c4", "#cbc693", "black"); + return make_tuple("filled", "#fff9c4", "#cbc693", "black"); + case kDashedBorder: + // "filled,dashed" looks the same as "dashed", since we have a white + // background. 
But we use "filled,dashed" so that when you hover over + // any part of the node (not just the text inside the node), our css + // :hover rule is triggered. + return make_tuple("filled,dashed", "white", "#757575", "#757575"); } }(); return Printf( - "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"", + R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style, font_color, stroke_color, fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return tensorflow::str_util::StringReplace( - tensorflow::str_util::StringReplace(s, "<", "<", /*replace_all=*/true), - ">", ">", /*replace_all=*/true); + return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", + ">", /*replace_all=*/true); } -// Returns the dot graph identifier for the given instruction. -string InstructionId(const HloInstruction* instruction) { - return Printf("%lld", reinterpret_cast(instruction)); -} +// Tries to generates a human-readable one-word description of the given +// computation. +// +// Currently we support: +// +// "return param0 + param1;" --> "add" +// "return param0 * param1;" --> "multiply" +// "return min(param0, param1);" --> "min" +// "return max(param0, param1);" --> "max" +// "return param0 <= param1;" --> "less-or-equal" +// "return param0 >= param1;" --> "greater-or-equal" +// "return param0 > param1;" --> "greater-than" +// "return param0 < param1;" --> "less-than" +// "return param0 == param1;" --> "equal-to" +// "return param0 != param1;" --> "not-equal-to" +// +// where param0 and param1 are effective scalars. For the ops that are +// commutative, we also support them with param0 and param1 swapped. +// +// This is useful primarily for reduce and map nodes. 
These take a +// subcomputation which is almost always one of the four above, and pattern +// matching it to a short string lets us tell the user what the subcomputation +// is without drawing it as a graph. +optional MatchTrivialComputation(const HloComputation* computation) { + if (computation->instruction_count() != 3) { + return nullopt; + } -// Returns the dot graph identifier for the given computation. -string ComputationId(const HloComputation* computation) { - return Printf("%lld", reinterpret_cast(computation)); -} + HloInstruction* root = computation->root_instruction(); + if (root->operand_count() != 2) { + return nullopt; + } -// Returns the dot graph edges and nodes for the given instruction sequence. -// Edges which extend between computations are added to the vector -// intercomputation_edges. This is necessary because graphviz does not render -// the graph properly unless these inter-computation edges appear after all -// subgraph statements. -string InstructionSequenceGraph( - const std::list>& instructions, - bool show_addresses, bool show_layouts, - std::vector* intercomputation_edges, - const HloExecutionProfile* hlo_execution_profile) { - string graph_body; - - // Create a single "record" node for the parameters. This node is a - // partitioned rectangle with one partition per parameter node. The keeps - // all the parameter instructions together. 
- std::vector param_instructions; - for (auto& instruction : instructions) { - if (instruction->opcode() == HloOpcode::kParameter) { - size_t param_number = instruction->parameter_number(); - - if (param_instructions.size() < param_number + 1) { - param_instructions.resize(param_number + 1, nullptr); - } - param_instructions[param_number] = instruction.get(); - } - } - string param_node_name; - if (!param_instructions.empty()) { - std::vector param_ports; - param_node_name = - StrCat("parameters_", InstructionId(param_instructions[0])); - for (auto& param : param_instructions) { - string label = StrCat(param->parameter_name(), "\\n", - ShapeUtil::HumanString(param->shape())); - if (show_addresses) { - Appendf(&label, "\\n[%p]", param); - } - if (show_layouts) { - StrAppend(&label, "\\nlayout=\\{", - Join(param->shape().layout().minor_to_major(), ","), "\\}"); - } - param_ports.push_back( - Printf("<%s> %s", InstructionId(param).c_str(), label.c_str())); - } - // (If we wanted the word "parameters" to be bold like the other op names, - // we'd have to make this into an HTML-like table. It is possible but - // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.) - StrAppend(&graph_body, param_node_name, " [shape=record ", - NodeColorAttributes(kOrange), "label=\"{parameters | {", - Join(param_ports, "|"), "}}\"];\n"); - } - - for (auto& instruction : instructions) { - ColorScheme color = kYellow; - string shape = "box"; - string name = - StrCat("", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), - " ", HtmlLikeStringSanitize(instruction->name())); - if (HloOpcode::kConvolution == instruction->opcode()) { - StrAppend( - &name, "
", - HtmlLikeStringSanitize( - instruction->ConvolutionDimensionNumbersToString()), - "
", - HtmlLikeStringSanitize(window_util::ToString(instruction->window()))); - } - - if (!instruction->metadata().op_name().empty()) { - StrAppend(&name, "
", - HtmlLikeStringSanitize(instruction->metadata().op_name())); - } - if (!instruction->metadata().source_file().empty() && - instruction->metadata().source_line() != 0) { - StrAppend(&name, "
", instruction->metadata().source_file(), ":", - instruction->metadata().source_line()); - } - - // Pick different colors or shapes for instructions which are particularly - // expensive (eg, dot) and those which are unusual in some way or unique - // (eg, parameter). - switch (instruction->opcode()) { - // "Normal" instructions. Mostly cheap and elementwise. No call to - // embedded computations. In this case, use default color, shape and - // label. - case HloOpcode::kAbs: - case HloOpcode::kAdd: - case HloOpcode::kCeil: - case HloOpcode::kClamp: - case HloOpcode::kConvert: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kExp: - case HloOpcode::kFloor: + // Check that both of the operands to the root are parameters. + const HloInstruction* operand0 = root->operand(0); + const HloInstruction* operand1 = root->operand(1); + if (operand0->opcode() != HloOpcode::kParameter || + operand1->opcode() != HloOpcode::kParameter) { + return nullopt; + } + + // Check that the two operands of root are param0 and param1. All of the + // opcodes we recognize are commutative, so we're OK with either order. + auto n0 = operand0->parameter_number(); + auto n1 = operand1->parameter_number(); + if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { + return nullopt; + } + + // If the params are reversed, check that the operation being performed is + // commutative. 
+ if (n0 == 1) { + switch (root->opcode()) { + case HloOpcode::kLe: case HloOpcode::kGe: case HloOpcode::kGt: - case HloOpcode::kIndex: - case HloOpcode::kIsFinite: - case HloOpcode::kLe: - case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kNegate: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSelect: - case HloOpcode::kSign: - case HloOpcode::kSlice: - case HloOpcode::kSort: - case HloOpcode::kSubtract: - case HloOpcode::kTanh: + return nullopt; + default: break; + } + } + + // Check that the root and params are all effective scalars. + if (!ShapeUtil::IsEffectiveScalar(root->shape()) || + !ShapeUtil::IsEffectiveScalar(operand0->shape()) || + !ShapeUtil::IsEffectiveScalar(operand1->shape())) { + return nullopt; + } + + // If we recognize the root's opcode, we've successfully pattern-matched! + switch (root->opcode()) { + case HloOpcode::kAdd: + return "add"; + case HloOpcode::kMultiply: + return "multiply"; + case HloOpcode::kMinimum: + return "min"; + case HloOpcode::kMaximum: + return "max"; + case HloOpcode::kLe: + return "less-or-equal"; + case HloOpcode::kGe: + return "greater-or-equal"; + case HloOpcode::kGt: + return "greater-than"; + case HloOpcode::kLt: + return "less-than"; + case HloOpcode::kEq: + return "equal-to"; + case HloOpcode::kNe: + return "not-equal-to"; + default: + return nullopt; + } +} + +// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). 
+class HloDotDumper { + public: + HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + bool show_addresses, bool show_layouts, + const HloExecutionProfile* profile, NodeFilter filter) + : computation_(computation), + label_(label.ToString()), + show_addresses_(show_addresses), + show_layouts_(show_layouts), + profile_(profile), + filter_(std::move(filter)) {} + + string Dump(); + + private: + // Returns the dot graph identifier for the given instruction. + string InstructionId(const HloInstruction* instruction) { + return StrCat(reinterpret_cast(instruction)); + } + + // Returns the dot graph identifier for the given computation. + string SubcomputationId(const HloComputation* computation) { + return StrCat("cluster_", reinterpret_cast(computation)); + } + + // Generates graph header/footer. These should be called *after* dumping all + // of the instructions and subcomputations for the graph, as they both use + // data generated while dumping the graph. + string Header(); + string Footer(); + + // Maps HloComputations we should dump to their parent instruction in the + // outer computation. + std::unordered_map + SubcomputationsToDump(); + + string DumpSubcomputation(const HloComputation* subcomp, + const HloInstruction* parent_instr); + string DumpComputation(const HloComputation* comp); + string DumpInstruction(const HloInstruction* instr); + ColorScheme GetInstructionColor(const HloInstruction* instr); + string GetInstructionNodeShape(const HloInstruction* instr); + string GetInstructionNodeLabel(const HloInstruction* instr); + string GetInstructionNodeExtraInfo(const HloInstruction* instr); + string GetInstructionNodeInlinedConstants(const HloInstruction* instr); + void AddInstructionIncomingEdges(const HloInstruction* instr); + + // If instr has just one computation and it's trivial (e.g. "return param0 + + // param1"), returns a string you can put into the node's body that names the + // subcomputation, e.g. "Subcomputation: add". 
+ string GetInstructionTrivialComputationStr(const HloInstruction* instr); + + const HloComputation* computation_; // never null + const string label_; // overall name for the graph + const bool show_addresses_; + const bool show_layouts_; + const HloExecutionProfile* profile_; // may be null + const NodeFilter filter_; + + // Each HloInstruction dumped gets a monotically-increasing node ID. This + // must start at 1, because that's where graphviz's accounting starts. + int64 next_node_id_ = 1; + std::unordered_map node_ids_; + + // Each (from, to) edge gets a monotonically-increasing ID. This is a + // multimap because it's possible for the same edge to appear multiple times + // in the graph (e.g. x^2 may be represented as mul(x, x)). + int64 next_edge_id_ = 1; + std::unordered_multimap< + std::pair, int64, + tensorflow::hash>> + edge_ids_; + + // Each HloComputation that's emitted gets a monotonically-increasing ID. + int64 next_cluster_id_ = 1; + std::unordered_map cluster_ids_; + + // Edges to print from Footer(). Edges come at the end because graphviz is + // unhappy if an edge from a subcomputation to a node in the outer computation + // appears before both the inner computation and the destination node are + // defined. + std::vector edges_; +}; + +string HloDotDumper::Dump() { + string body; + for (const auto& kv : SubcomputationsToDump()) { + const HloComputation* subcomp = kv.first; + const HloInstruction* parent = kv.second; + StrAppend(&body, DumpSubcomputation(subcomp, parent)); + } + StrAppend(&body, DumpComputation(computation_)); + + // By contract, Header() and Footer() have to be called after we've dumped all + // our instructions, because they use state generated during that process. + string g = Header(); + StrAppend(&g, body); + StrAppend(&g, Footer()); + return g; +} + +string HloDotDumper::Header() { + const char* fmt = R"(digraph G { +rankdir = TB; +compound = true; +label = <%s>; +labelloc = t; +// Disable the tooltip. 
Interestingly, "" doesn't work! +tooltip = " "; +// DOT graphs accept a stylesheet as a URI. So naturally, an inline +// stylesheet is a data URI! +stylesheet=" + data:text/css, + @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); + svg text { + font-family: 'Roboto'; + font-size: 12px; + } + + %s +" + +)"; + + string graph_label = StrCat(label_, "
", computation_->name()); + if (profile_ != nullptr) { + auto cycles = profile_->total_cycles_executed(*computation_); + Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); + } + + // Create CSS rules that say, when you hover over the given node or cluster, + // turn the given edge the given color. + // + // We rely on a few properties of how graphviz generates SVGs: + // + // - Nodes are named "nodeN", where N corresponds to the 1-based index of + // the node in our DOT (i.e. the first node in the DOT is "node1", etc.). + // Edges are similarly named "edgeN", and clusters are named "clustN". + // - Nodes come before their in- and out-edges in the SVG. We need this + // because the "X ~ Y" CSS selector finds a sibling of X that *comes + // after X in the DOM* and matches Y. + std::vector edge_css_rules; + const char* kBlue = "#1976d2"; + const char* kRed = "#d32f2f"; + for (const auto& kv : edge_ids_) { + const HloInstruction* from_node = kv.first.first; + const HloInstruction* to_node = kv.first.second; + int64 edge_id = kv.second; + + auto add_hover_css_rule = [&](string elem_type, int64 elem_id, + const char* color) { + // One could imagine other ways of writing this CSS rule that involve less + // duplication, but this way seems to be relatively performant. + edge_css_rules.push_back(Printf( + " #%s%d:hover ~ #edge%lld text { fill: %s; }\n" + " #%s%d:hover ~ #edge%lld path { stroke: %s; stroke-width: .2em; }\n" + " #%s%d:hover ~ #edge%lld polygon { " + "fill: %s; stroke: %s; stroke-width: .2em; }\n", + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, color)); + }; + + int64 from_node_id = node_ids_.at(from_node); + int64 to_node_id = node_ids_.at(to_node); + add_hover_css_rule("node", from_node_id, kBlue); + add_hover_css_rule("node", to_node_id, kRed); + + // If this edge crosses a fusion cluster boundary, highlight it when the + // cluster is hovered over. 
+ if (from_node->IsFused() && + from_node->fusion_instruction()->fused_expression_root() == from_node) { + int64 cluster_id = cluster_ids_.at(from_node->parent()); + add_hover_css_rule("clust", cluster_id, kBlue); + } + if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) { + int64 cluster_id = cluster_ids_.at(to_node->parent()); + add_hover_css_rule("clust", cluster_id, kRed); + } + } + + return Printf(fmt, graph_label, Join(edge_css_rules, "\n")); +} + +string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } + +std::unordered_map +HloDotDumper::SubcomputationsToDump() { + // Dump the subcomputations of each instruction that's shown and doesn't have + // its operands omitted. If an instruction has just one subcomputation and + // it's trivial, omit it: We'll display that subcomputation inlined into the + // instruction's node when we draw it. + std::unordered_map to_dump; + for (const auto& instr : computation_->instructions()) { + if (!filter_.Show(instr.get()) || + filter_.SomeOrAllOperandsOmitted(instr.get())) { + continue; + } + if (instr->opcode() == HloOpcode::kFusion) { + to_dump[instr->fused_instructions_computation()] = instr.get(); + } + + for (const HloComputation* comp : instr->called_computations()) { + if (!MatchTrivialComputation(comp)) { + to_dump[comp] = instr.get(); + } + } + } + return to_dump; +} + +string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, + const HloInstruction* parent_instr) { + const char* computation_fmt = R"(subgraph %s { +%s; +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s + +)"; + + cluster_ids_[subcomp] = next_cluster_id_++; + + string id = SubcomputationId(subcomp); + + string subcomp_label, style; + if (parent_instr->opcode() == HloOpcode::kFusion) { + subcomp_label = Printf("Fused expression for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); + + // Subcomputation's fill/stroke color is light/dark red/gray, depending on + // whether or not the subcomputation's fusion node is highlighted. + bool highlight = filter_.Highlight(parent_instr); + const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; + const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + style = Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s")", + fillcolor, strokecolor); + } else { + subcomp_label = Printf("Subcomputation for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); + style = "style=rounded; color=black;"; + } + + string comp_body = DumpComputation(subcomp); + string computation = + Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, so + // there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + edge_ids_.insert( + {{subcomp->root_instruction(), parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back( + Printf(edge_fmt, InstructionId(subcomp->root_instruction()), + InstructionId(parent_instr), SubcomputationId(subcomp), + subcomp->name(), parent_instr->name())); + } + + return computation; +} + +string HloDotDumper::DumpComputation(const HloComputation* comp) { + string g; + for (const auto& instr : comp->instructions()) { + if (!filter_.Show(instr.get())) { + continue; + } + StrAppend(&g, DumpInstruction(instr.get())); + } + return g; +} + +string HloDotDumper::DumpInstruction(const HloInstruction* instr) { + // We don't display constants as separate nodes; they're merged into their + // users. + if (instr->opcode() == HloOpcode::kConstant) { + return ""; + } + // Omit the fusion node if its subcomputation is drawn, since the + // subcomputation will be drawn inline. 
+ if (instr->opcode() == HloOpcode::kFusion && + filter_.ShowFusionSubcomputation(instr)) { + return ""; + } + + node_ids_[instr] = next_node_id_++; + + ColorScheme color = GetInstructionColor(instr); + string node_shape = GetInstructionNodeShape(instr); + string node_label = GetInstructionNodeLabel(instr); + string extra_info = GetInstructionNodeExtraInfo(instr); + string inlined_constants = GetInstructionNodeInlinedConstants(instr); + string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); + AddInstructionIncomingEdges(instr); + + // Override the node's styling if it should be (de-)emphasized. + if (filter_.Deemphasized(instr)) { + color = kDashedBorder; + } + if (filter_.Highlight(instr)) { + node_shape = "diamond"; + color = kDarkRed; + } + + // Build the text that will be displayed inside the node. + string node_body = node_label; + for (const string& s : + {trivial_subcomputation, extra_info, inlined_constants}) { + if (!s.empty()) { + StrAppend(&node_body, "
", s); + } + } + + return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + "\n", + InstructionId(instr), node_body, node_shape, + NodeColorAttributes(color)); +} + +string HloDotDumper::GetInstructionNodeInlinedConstants( + const HloInstruction* instr) { + auto stringify_constant = [](const HloInstruction* constant) { + if (ShapeUtil::IsEffectiveScalar(constant->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + constant->shape(), /*linear_index=*/0); + return Printf("%s{%s}", ShapeUtil::HumanString(constant->shape()), + constant->literal().GetAsString(elem_idx)); + } + if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { + return constant->name(); + } + return StrCat("constant ", constant->name()); + }; + + // Special case: If instr is a parameter to a fusion node, check whether the + // corresponding operand to the fusion node is a constant. + if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { + const HloInstruction* fusion = instr->fusion_instruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + if (operand->opcode() != HloOpcode::kConstant) { + return ""; + } + return stringify_constant(operand); + } + + std::vector lines; + for (int64 i = 0; i < instr->operand_count(); ++i) { + const HloInstruction* operand = instr->operand(i); + if (operand->opcode() != HloOpcode::kConstant) { + continue; + } + lines.push_back( + Printf("operand %lld = %s", i, stringify_constant(operand))); + } + return Join(lines, "
"); +} + +ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + // Pick different colors or shapes for instructions which are particularly + // expensive (eg, dot) and those which are unusual in some way or unique + // (eg, parameter). + switch (instr->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kIndex: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalNot: + case HloOpcode::kLogicalOr: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kRng: + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + return kYellow; + case HloOpcode::kBitcast: + case HloOpcode::kTuple: + case HloOpcode::kTrace: + case HloOpcode::kGetTupleElement: + return kWhite; + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kUpdate: + return kGreen; + case HloOpcode::kConvolution: + case HloOpcode::kDot: + return kDarkBlue; + case HloOpcode::kReducePrecision: + return kRed; + case HloOpcode::kParameter: + return kOrange; + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kReduce: + case HloOpcode::kSelectAndScatter: + case 
HloOpcode::kReduceWindow: + return kPurple; + case HloOpcode::kMap: + case HloOpcode::kFusion: + return kGray; + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kCrossReplicaSum: + return kBrown; + case HloOpcode::kCustomCall: + case HloOpcode::kWhile: + case HloOpcode::kCall: + return kDarkGreen; + case HloOpcode::kConstant: + LOG(FATAL) << "Constants don't get their own nodes in the graph."; + } +} + +string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { + // Give while loops a different shape so they're easier to pick out. + switch (instr->opcode()) { + case HloOpcode::kWhile: + return "ellipse"; + default: + return "rect"; + } +} + +string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { + // If we have a parameter, put the param number in the name. + if (instr->opcode() == HloOpcode::kParameter) { + return Printf("Parameter %lld", instr->parameter_number()); + } + + // The HLO instruction name contains usually the opcode, e.g. "%add.42" is + // an add instruction. In this case we render just the name. + if (tensorflow::StringPiece(instr->name()) + .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { + return Printf("%s", HtmlLikeStringSanitize(instr->name())); + } + + // If the name does not contain the opcode, render both. + return Printf("%s
%s", + HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()), + HtmlLikeStringSanitize(instr->name())); +} + +string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { + string opcode_specific_info = [&]() -> string { + switch (instr->opcode()) { case HloOpcode::kRng: - StrAppend(&name, "
", - RandomDistribution_Name(instruction->random_distribution())); - break; + return RandomDistribution_Name(instr->random_distribution()); + case HloOpcode::kConvolution: + return StrCat( + HtmlLikeStringSanitize( + instr->ConvolutionDimensionNumbersToString()), + "
", + HtmlLikeStringSanitize(window_util::ToString(instr->window()))); case HloOpcode::kBroadcast: case HloOpcode::kTranspose: - StrAppend(&name, "
", "dims={", - Join(instruction->dimensions(), ","), "}"); - break; - case HloOpcode::kBitcast: - case HloOpcode::kTuple: - case HloOpcode::kTrace: - color = kWhite; - break; - case HloOpcode::kGetTupleElement: - color = kWhite; - StrAppend(&name, "
index=", instruction->tuple_index()); - break; - case HloOpcode::kConcatenate: - case HloOpcode::kCopy: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kPad: - case HloOpcode::kReshape: - case HloOpcode::kReverse: - case HloOpcode::kUpdate: - color = kGreen; - break; - case HloOpcode::kConstant: - color = kBlue; - break; - case HloOpcode::kConvolution: - case HloOpcode::kDot: - color = kDarkBlue; - break; - case HloOpcode::kParameter: - // A single record node is created for all the parameter nodes with a - // port for each parameter instruction. No need to emit anything in this - // case. - continue; case HloOpcode::kReduce: - StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); - color = kPurple; - break; - case HloOpcode::kSelectAndScatter: - case HloOpcode::kReduceWindow: - color = kPurple; - break; - case HloOpcode::kWhile: - shape = "ellipse"; - color = kDarkGreen; - break; - case HloOpcode::kMap: - case HloOpcode::kFusion: - color = kGray; - break; - case HloOpcode::kSend: - case HloOpcode::kRecv: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: - color = kBrown; - break; - case HloOpcode::kCall: - color = kDarkGreen; - break; + return Printf("dims={%s}", Join(instr->dimensions(), ",")); + case HloOpcode::kGetTupleElement: + return Printf("index=%lld", instr->tuple_index()); + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: + return Printf("feature_index=%lld", instr->feature_index()); case HloOpcode::kCustomCall: - color = kDarkGreen; - StrAppend(&name, "
", - "custom_call_target=", instruction->custom_call_target()); - break; + return Printf("custom_call_target=%s", instr->custom_call_target()); + default: + return ""; } + }(); - // Create instruction node with appropriate label, shape, and color. - // label is interpreted as an HTML-like string, so newlines must be - // delimited with
, rather than \n. - string label = - StrCat(name, "
", ShapeUtil::HumanString(instruction->shape())); + std::vector lines; + if (!opcode_specific_info.empty()) { + lines.push_back(opcode_specific_info); + } - if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(instruction->shape())) { - auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( - instruction->shape(), /*linear_index=*/0); - StrAppend(&label, " = {", - LiteralUtil::GetAsString(instruction->literal(), elem_idx), - "}"); - } - - if (show_addresses) { - Appendf(&label, "
[%p]", instruction.get()); - } - if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) { - string layout_string; - if (ShapeUtil::IsTuple(instruction->shape())) { - // For tuples, emit the full shape because the layout of a tuple is not - // represented in a single Layout field. - layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); - } else { - layout_string = - Join(instruction->shape().layout().minor_to_major(), ","); - } - StrAppend(&label, "
layout={", layout_string, "}"); - } - if (hlo_execution_profile != nullptr) { - auto hlo_cycles_executed = - hlo_execution_profile->GetProfileResult(*instruction); - auto total_cycles_executed = - hlo_execution_profile->total_cycles_executed(*instruction->parent()); - if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { - Appendf(&label, "
%% of cycles executed=%.2f", - (static_cast(hlo_cycles_executed) / - static_cast(total_cycles_executed)) * - 100); - } - } + // Some instructions have giant tuples as their shapes, so truncate the HLO's + // shape to kMaxShapeLen characters. + constexpr int kMaxShapeLen = 64; + string instr_shape = ShapeUtil::HumanString(instr->shape()); + if (instr_shape.length() > kMaxShapeLen) { + instr_shape = + StrCat(tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), + "..."); + } + lines.push_back(instr_shape); - Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", - InstructionId(instruction.get()).c_str(), label.c_str(), - shape.c_str(), NodeColorAttributes(color).c_str()); - - // Create edges from the instruction's operands to the instruction. - int64 operand_number = 0; - for (auto* operand : instruction->operands()) { - string src; - if (operand->opcode() == HloOpcode::kParameter) { - // If operand is a parameter, then select the proper partition (port) in - // the unified parameter node. - src = param_node_name + ":" + InstructionId(operand); - } else { - src = InstructionId(operand); - } - Appendf(&graph_body, "%s -> %s", src.c_str(), - InstructionId(instruction.get()).c_str()); - if (instruction->operand_count() > 1) { - Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]", - operand_number); - } - StrAppend(&graph_body, ";\n"); - ++operand_number; - } - - // Fusion nodes are handled specially because they contain nested - // expressions. 
- if (instruction->opcode() == HloOpcode::kFusion) { - string cluster_name = - StrCat("cluster_", InstructionId(instruction.get())); - StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); - StrAppend(&graph_body, - "label=<fused expression>;\nstyle=\"rounded,filled\";\n" - "color=lightgrey;\n"); - StrAppend(&graph_body, InstructionSequenceGraph( - instruction->fused_instructions(), - show_addresses, show_layouts, - intercomputation_edges, hlo_execution_profile), - "}\n"); - string fusion_edge = - StrCat(InstructionId(instruction->fused_expression_root()), " -> ", - InstructionId(instruction.get()), - " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name, - " ];\n"); - intercomputation_edges->push_back(fusion_edge); + if (show_addresses_) { + lines.push_back(Printf("[%p]", instr)); + } + if (show_layouts_ && LayoutUtil::HasLayout(instr->shape())) { + string layout_str; + if (ShapeUtil::IsTuple(instr->shape())) { + // For tuples, emit the full shape because the layout of a tuple is not + // represented in a single Layout field. + layout_str = ShapeUtil::HumanStringWithLayout(instr->shape()); } else { - // Add a dotted edge between the instruction and any computations that the - // instruction calls. 
- for (const HloComputation* computation : - instruction->called_computations()) { - string cluster_name = StrCat("cluster_", ComputationId(computation)); - string call_edge = Printf( - "%s -> %s [ style=dashed; ltail=%s ];\n", - InstructionId(computation->root_instruction()).c_str(), - InstructionId(instruction.get()).c_str(), cluster_name.c_str()); - intercomputation_edges->push_back(call_edge); - } + layout_str = Join(instr->shape().layout().minor_to_major(), ","); + } + lines.push_back(Printf("layout={%s}", layout_str)); + } + if (profile_ != nullptr) { + double hlo_cycles_executed = profile_->GetProfileResult(*instr); + double total_cycles_executed = + profile_->total_cycles_executed(*instr->parent()); + if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { + lines.push_back( + Printf("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } } - return graph_body; + return Join(lines, "
"); } -// DOT graphs accept a stylesheet as a URL. So naturally, an inline stylesheet -// is a data URI! -// -// We don't perform any escaping on this string, so be careful not to use double -// quotes inside. -static const char* dot_stylesheet = R"( -data:text/css, -@import url(https://fonts.googleapis.com/css?family=Roboto:400,700); -svg text { - font-family: 'Roboto'; - font-size: 12px; +void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { + auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, + int64 operand_num) { + // Fusion nodes' subcomputations are displayed inline, so if 'from' is a + // fusion node and the node's subcomputation is shown, we draw our edge + // starting at the fusion node's root instead of at the fusion node itself. + if (from->opcode() == HloOpcode::kFusion && + filter_.ShowFusionSubcomputation(from)) { + from = from->fused_expression_root(); + } + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { + return; + } + edge_ids_.insert({{from, to}, next_edge_id_++}); + + string edge_label; + if (instr->operand_count() > 1) { + edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); + } + const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; + edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), + from->name(), to->name(), edge_label)); + }; + + // Add edges from instr's operands to instr. Parameters within fusion + // expressions are handled specially -- we draw an edge from the corresponding + // operand on the fusion node itself to the parameter. 
+ if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { + const HloInstruction* fusion = instr->fusion_instruction(); + add_edge(fusion->operand(instr->parameter_number()), instr, + /*operand_num=*/0); + } else { + for (int64 i = 0; i < instr->operand_count(); ++i) { + add_edge(instr->operand(i), instr, i); + } + } } -)"; -string ComputationToDotGraph(const HloComputation& computation, - const string& label, bool show_addresses, - bool show_layouts, - const HloExecutionProfile* hlo_execution_profile) { - string graph_label = StrCat(label, "
", computation.name()); - if (hlo_execution_profile != nullptr) { - auto cycles = hlo_execution_profile->total_cycles_executed(computation); - Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles).c_str()); - } - string graph = Printf( - R"(digraph G { -rankdir=TB; -compound=true; -label=<%s>; -labelloc=t; -stylesheet="%s" -)", - graph_label.c_str(), dot_stylesheet); - - // Emit embedded computations as subgraph clusters. - std::vector intercomputation_edges; - for (auto embedded : computation.MakeEmbeddedComputationsList()) { - string graph_body = InstructionSequenceGraph( - embedded->instructions(), show_addresses, show_layouts, - &intercomputation_edges, hlo_execution_profile); - Appendf(&graph, - "subgraph cluster_%s " - "{\nstyle=rounded;label=<%s>;labelloc=t;\n%s}\n", - ComputationId(embedded).c_str(), embedded->name().c_str(), - graph_body.c_str()); - } - StrAppend(&graph, - InstructionSequenceGraph(computation.instructions(), show_addresses, - show_layouts, &intercomputation_edges, - hlo_execution_profile)); - - // Edges between computations (subgraph clusters) must be emitted last for the - // graph to be rendered properly for some reason. - StrAppend(&graph, Join(intercomputation_edges, "\n"), "}\n"); - - return graph; +string HloDotDumper::GetInstructionTrivialComputationStr( + const HloInstruction* instr) { + // called_computations() on a fusion node "inherits" any called computations + // of the fused root, which isn't what we want. Just ignore fusion nodes + // here; they're handled separately. + if (instr->opcode() == HloOpcode::kFusion) { + return ""; + } + + std::vector lines; + for (int64 i = 0; i < instr->called_computations().size(); ++i) { + optional computation_type = + MatchTrivialComputation(instr->called_computations()[i]); + if (!computation_type) { + continue; + } + if (instr->called_computations().size() == 1) { + lines.push_back(Printf("Subcomputation: %s", + HtmlLikeStringSanitize(*computation_type))); + } else { + lines.push_back(Printf("Subcomputation %lld: %s", i, + HtmlLikeStringSanitize(*computation_type))); + } + } + return Join(lines, "
"); } tensorflow::mutex& RendererMutex() { @@ -508,10 +956,9 @@ namespace { class FileGraphRenderer : public GraphRendererInterface { public: - string RenderGraph(const string& graph, GraphKind graph_kind) override { + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { static std::atomic output_num(0); - legacy_flags::HloGraphDumperFlags* flags = - legacy_flags::GetHloGraphDumperFlags(); string file_extension; switch (graph_kind) { case DOT_GRAPH: @@ -522,7 +969,7 @@ class FileGraphRenderer : public GraphRendererInterface { break; } string path = - JoinPath(flags->xla_hlo_dump_graph_path, + JoinPath(debug_options.xla_hlo_graph_path(), StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); auto status = Status::OK(); int fd = mkstemps(&path[0], file_extension.length()); @@ -543,18 +990,118 @@ class FileGraphRenderer : public GraphRendererInterface { } }; +// Gets a NodeFilter that includes roughly all instructions whose distance from +// root is <= radius. +NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { + // First, find the neighborhood of nodes with distance from root <= radius. + // These nodes are our initial set of "normal" nodes. + std::unordered_map nodes; + std::deque> worklist; + worklist.push_back({root, 0}); + while (!worklist.empty()) { + const HloInstruction* instr; + int64 depth; + std::tie(instr, depth) = worklist.front(); + worklist.pop_front(); + + nodes[instr] = kNormalNode; + if (depth == radius) { + continue; + } + + // Traverse into instr's operands. + // + // Don't traverse into tuples' operands unless the tuple is the root. + // Usually a tuple is the bottommost node in the graph, and so its operands + // are not interesting to the graph at hand. 
+ if (instr == root || instr->opcode() != HloOpcode::kTuple) { + for (const HloInstruction* operand : instr->operands()) { + if (!nodes.count(operand)) { + worklist.push_back({operand, depth + 1}); + } + } + } + + // Traverse into instr's users, unless: + // + // - there are a ton of them, in which case they're probably not + // interesting (and anyway, rendering them all would make the graph + // unreadable), or + // - instr is a constant, in which case its users are probably not + // interesting. + if (instr->opcode() == HloOpcode::kConstant) { + continue; + } + constexpr int kMaxUsersToRender = 16; + if (instr->user_count() > kMaxUsersToRender) { + // If we're going to skip this node's users, style it as such. + nodes[instr] = kSomeUsersOmitted; + continue; + } + for (const HloInstruction* user : instr->users()) { + if (!nodes.count(user)) { + worklist.push_back({user, depth + 1}); + } + } + } + + auto is_displayed = [&](const HloInstruction* instr) { + // Constants are displayed inline with their users; they're never omitted. + return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant; + }; + + // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we + // know which nodes will be included in the graph. + for (auto& kv : nodes) { + const HloInstruction* instr = kv.first; + NodeFilterResult& filter_result = kv.second; + const auto& operands = instr->operands(); + + if (std::any_of(operands.begin(), operands.end(), is_displayed) && + !std::all_of(operands.begin(), operands.end(), is_displayed)) { + // Mark nodes with some operands omitted appropriately. + filter_result = kSomeOperandsOmitted; + } else if (!operands.empty() && + std::none_of(operands.begin(), operands.end(), is_displayed)) { + // Mark nodes with *all* operands omitted appropriately. + filter_result = kOmitNodeOperands; + } + + // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their + // users made it into the graph. 
+ if (filter_result == kSomeUsersOmitted && + std::all_of(instr->users().begin(), instr->users().end(), + is_displayed)) { + filter_result = kNormalNode; + } + } + + // Highlight the root node. + nodes[root] = kHighlightNode; + + return NodeFilter([=](const HloInstruction* instr) { + auto it = nodes.find(instr); + if (it != nodes.end()) { + return it->second; + } + // Show all nodes in subcomputations. + if (instr->parent() != root->parent()) { + return kNormalNode; + } + return kHideNode; + }); +} + XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); } // namespace string DumpGraph(const HloComputation& computation, const string& label, - bool show_addresses, bool show_layouts, + const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile) { string graph; string graph_url; - legacy_flags::HloGraphDumperFlags* flags = - legacy_flags::GetHloGraphDumperFlags(); - if (flags->xla_hlo_dump_as_graphdef) { + if (debug_options.xla_hlo_dump_as_graphdef()) { HloTfGraphBuilder builder; TF_CHECK_OK(builder.AddComputation(computation)); CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), @@ -563,18 +1110,37 @@ string DumpGraph(const HloComputation& computation, const string& label, // renderers support rendering GraphDefs. Always dump GraphDefs to files // for now. 
graph_url = FileGraphRenderer().RenderGraph( - graph, GraphRendererInterface::TF_GRAPHDEF); + graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); } else { - graph = ComputationToDotGraph(computation, label, show_addresses, - show_layouts, hlo_execution_profile); + graph = + HloDotDumper(&computation, label, + /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), + /*show_layouts=*/debug_options.xla_hlo_graph_layout(), + hlo_execution_profile, NodeFilter()) + .Dump(); graph_url = GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH); + graph, GraphRendererInterface::DOT_GRAPH, debug_options); } LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; } +string DumpNeighborhoodAround(const HloInstruction& node, int radius) { + auto debug_options = node.GetModule()->config().debug_options(); + string label = + StrCat("Neighborhood of ", radius, " nodes around ", node.name()); + NodeFilter filter = MakeNodeFilter(&node, radius); + string graph = + HloDotDumper(node.parent(), label, + /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), + /*show_layouts=*/debug_options.xla_hlo_graph_layout(), + /*profile=*/nullptr, filter) + .Dump(); + return GetGraphRenderer()->RenderGraph( + graph, GraphRendererInterface::DOT_GRAPH, debug_options); +} + void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix) { Env* env = Env::Default(); @@ -584,6 +1150,30 @@ void DumpText(const HloModule& module, const string& label, do_prefix ? 
StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); + LOG(INFO) << "dumping module '" << module.name() << "' to " << path; +} + +string MaybeDumpHloModule(const HloModule& module, const string& label, + const HloExecutionProfile* profile) { + VLOG(2) << "MaybeDumpHloModule called on module " << module.name(); + string graph_url; + const DebugOptions& debug_options = module.config().debug_options(); + if (!debug_options.xla_generate_hlo_graph().empty() && + RE2::PartialMatch(module.name(), + debug_options.xla_generate_hlo_graph())) { + graph_url = + DumpGraph(*module.entry_computation(), label, debug_options, profile); + } + if (!debug_options.xla_log_hlo_text().empty() && + RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { + LOG(INFO) << "HLO for module " << module.name(); + LOG(INFO) << "Label: " << label; + XLA_LOG_LINES(2, module.ToString()); + } + if (!debug_options.xla_generate_hlo_text_to().empty()) { + DumpText(module, label, debug_options.xla_generate_hlo_text_to()); + } + return graph_url; } } // namespace hlo_graph_dumper diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 8ed50c38473a6f6dd36603e155285e855ff0c5be..0100d50c050a30a2464b912fcf3688426618513e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" namespace xla { namespace hlo_graph_dumper { @@ -38,16 +39,31 @@ class GraphRendererInterface { // Renders a DOT graph, returning a description of the rendered output // (e.g., a URL) - virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0; + virtual string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) = 0; }; +// Dump the given HLO module if a dump is requested in its debug options. Based +// on the debug options, either a graph dump, a text dump or both may be +// generated. If a graph dump is generated, the description (e.g. an URL) is +// returned; otherwise an empty string is returned. +string MaybeDumpHloModule(const HloModule& module, const string& label, + const HloExecutionProfile* profile = nullptr); + // Dumps a graph of the computation and returns a description of the rendered // graph (e.g., a URL) based on the renderer. The "best" renderer in the // registry is used. string DumpGraph(const HloComputation& computation, const string& label, - bool show_addresses, bool show_layouts, + const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr); +// Like DumpGraph, but renders only nodes "near" the given node in the graph. +// +// The number of nodes dumped is controlled by the radius parameter, which +// (roughly) corresponds to the max distance a node may be from the primary node +// before it's omitted from the graph. +string DumpNeighborhoodAround(const HloInstruction& node, int radius); + // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. 
// diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ea813c98743f7c34a891a3b648a2818f5dada8ec..c11fea09d145815d2142e634d93d44dee6601edc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -122,6 +122,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: @@ -129,6 +130,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kLogicalNot: case HloOpcode::kNegate: case HloOpcode::kSign: + case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: break; @@ -226,6 +228,19 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateReducePrecision(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); + instruction->AppendOperand(operand); + instruction->exponent_bits_ = exponent_bits; + instruction->mantissa_bits_ = mantissa_bits; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum(const Shape& shape, HloInstruction* operand) { @@ -299,6 +314,12 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); instruction->slice_strides_.assign(strides.begin(), strides.end()); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. 
+ if (instruction->slice_strides_.empty()) { + instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); + } return instruction; } @@ -371,6 +392,40 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBatchNormTraining(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, float epsilon, + int64 feature_index) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(scale); + instruction->AppendOperand(offset); + instruction->epsilon_ = epsilon; + instruction->feature_index_ = feature_index; + return instruction; +} + +/* static */ std::unique_ptr +HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, + HloInstruction* scale, HloInstruction* mean, + HloInstruction* variance, + HloInstruction* grad_output, float epsilon, + int64 feature_index) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(scale); + instruction->AppendOperand(mean); + instruction->AppendOperand(variance); + instruction->AppendOperand(grad_output); + instruction->epsilon_ = epsilon; + instruction->feature_index_ = feature_index; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, @@ -505,19 +560,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(instruction_to_fuse->IsFusable()); - + if (GetModule()) { + XLA_VLOG_LINES(1, GetModule()->ToString()); + } HloInstruction* clone = nullptr; - if (fused_instructions_computation_ == nullptr) { + if (called_computations_.empty()) { // New fusion instruction. 
auto builder = HloComputation::Builder("fused_computation", true); builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - fused_instructions_computation_ = builder.Build(); + called_computations_.push_back( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); clone->parent_fusion_instruction_ = this; } else { - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - clone = fused_instructions_computation_->AddInstruction( + clone = fused_instructions_computation()->AddInstruction( instruction_to_fuse->Clone(/*suffix=*/"")); clone->parent_fusion_instruction_ = this; // instruction_to_fuse is necessarily an operand of the fusion instruction. @@ -528,7 +584,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) != operands_.end()); const std::vector& fused_parameters_ = - fused_instructions_computation_->parameter_instructions(); + fused_instructions_computation()->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { if (instruction_to_fuse == operands_[operand_num]) { // replace the fused parameter instruction's uses with the clone. @@ -538,7 +594,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Remove the corresponding fused parameter and operand from their // respective vectors. TF_CHECK_OK( - fused_instructions_computation_->RemoveParameter(operand_num)); + fused_instructions_computation()->RemoveParameter(operand_num)); operands_.erase(operands_.begin() + operand_num); break; } @@ -550,7 +606,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Reread the parameters in the computation. 
const std::vector& fused_parameters_ = - fused_instructions_computation_->parameter_instructions(); + fused_instructions_computation()->parameter_instructions(); // Add each operand of the clone as an operand of the fusion instruction. A // complication is that some clone operands may already be operands of the @@ -583,7 +639,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CreateParameter(param_no, operand->shape(), param_name); param_instruction->parent_fusion_instruction_ = this; - fused_param = fused_instructions_computation_->AddParameter( + fused_param = fused_instructions_computation()->AddParameter( std::move(param_instruction)); AppendOperand(operand); } @@ -597,7 +653,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( called_computations_.push_back(computation); } } - return clone; } @@ -608,17 +663,15 @@ RandomDistribution HloInstruction::random_distribution() const { void HloInstruction::CheckFusionInstruction() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); const std::list>& fused_instructions_ = - fused_instructions_computation_->instructions(); + fused_instructions_computation()->instructions(); // All instructions owned by this fusion instruction must be fused, and the // parent fusion instruction of the fused instructions must be 'this'. 
for (auto& instruction : fused_instructions_) { CHECK(instruction->IsFused()); CHECK_EQ(this, instruction->fusion_instruction()); - CHECK_EQ(fused_instructions_computation_.get(), instruction->parent()) + CHECK_EQ(fused_instructions_computation(), instruction->parent()) << instruction->ToString(); } @@ -730,6 +783,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kIsFinite: case HloOpcode::kFloor: @@ -737,6 +791,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLogicalNot: case HloOpcode::kNegate: case HloOpcode::kSign: + case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); @@ -780,6 +835,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); return CreateConvert(shape, new_operands[0]); + case HloOpcode::kReducePrecision: + CHECK_EQ(new_operands.size(), 1); + return CreateReducePrecision(shape, new_operands[0], exponent_bits_, + mantissa_bits_); case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -838,18 +897,31 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return CreateWhile(shape, while_condition(), while_body(), new_operands[0]); case HloOpcode::kConstant: - return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); + return CreateConstant(literal_->CloneToUnique()); case HloOpcode::kFusion: return CloneFusionWithNewOperands(shape, new_operands); case HloOpcode::kParameter: return CreateParameter(parameter_number_, shape, parameter_name_); - // Unsupported ops for cloning. 
+ case HloOpcode::kBatchNormTraining: + CHECK_EQ(new_operands.size(), 3); + return CreateBatchNormTraining(shape, new_operands[0], new_operands[1], + new_operands[2], epsilon(), + feature_index()); + case HloOpcode::kInfeed: + CHECK_EQ(new_operands.size(), 0); + return CreateInfeed(shape, infeed_config()); + case HloOpcode::kOutfeed: + CHECK_EQ(new_operands.size(), 1); + return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); + case HloOpcode::kBatchNormGrad: + CHECK_EQ(new_operands.size(), 5); + return CreateBatchNormGrad(shape, new_operands[0], new_operands[1], + new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); case HloOpcode::kRecv: case HloOpcode::kSend: case HloOpcode::kUpdate: case HloOpcode::kIndex: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -902,8 +974,6 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(parent() != nullptr); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); auto new_instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); @@ -918,9 +988,9 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( // fused instructions. 
std::vector new_fused_parameters; const std::vector& fused_parameters_ = - fused_instructions_computation_->parameter_instructions(); + fused_instructions_computation()->parameter_instructions(); const std::list>& fused_instructions_ = - fused_instructions_computation_->instructions(); + fused_instructions_computation()->instructions(); for (HloInstruction* old_fused_parameter : fused_parameters_) { new_fused_instructions.push_back(old_fused_parameter->Clone()); @@ -954,7 +1024,7 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( } new_instruction->fusion_kind_ = fusion_kind_; auto computation_builder = HloComputation::Builder( - fused_instructions_computation_->name() + ".clone", true); + fused_instructions_computation()->name() + ".clone", true); // We iterated the fusion instructions in reverse post order which means // that we must reverse our new list of fusion instructions. for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); @@ -963,8 +1033,10 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); } auto fused_root_ = fused_expression_root(); - new_instruction->fused_instructions_computation_ = - computation_builder.Build(FindOrDie(old_to_new, fused_root_)); + new_instruction->called_computations_.push_back( + CHECK_NOTNULL(GetModule()) + ->AddEmbeddedComputation( + computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); new_instruction->set_parent(parent()); new_instruction->CheckFusionInstruction(); return new_instruction; @@ -1041,7 +1113,7 @@ Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { auto pred_it = std::find(instruction->control_predecessors_.begin(), instruction->control_predecessors_.end(), this); TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); - instruction->control_predecessors_.erase(succ_it); + instruction->control_predecessors_.erase(pred_it); return Status::OK(); } @@ 
-1099,6 +1171,7 @@ bool HloInstruction::Identical( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: case HloOpcode::kDot: @@ -1123,6 +1196,7 @@ bool HloInstruction::Identical( case HloOpcode::kRemainder: case HloOpcode::kSelect: case HloOpcode::kSign: + case HloOpcode::kSin: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: @@ -1141,15 +1215,25 @@ bool HloInstruction::Identical( // different HloComputations. ShapeUtil::Compatible(shape(), other.shape()); + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: + return feature_index() == other.feature_index() && + epsilon() == other.epsilon(); + // A constant is defined by the value in the literal. case HloOpcode::kConstant: - return LiteralUtil::Equal(literal(), other.literal()); + return literal().Equal(other.literal()); // A convert result is determined by the primitive type that the operand is // converted into. case HloOpcode::kConvert: return shape().element_type() == other.shape().element_type(); + // A reduce-precision operation is determined by the bit sizes. + case HloOpcode::kReducePrecision: + return exponent_bits() == other.exponent_bits() && + mantissa_bits() == other.mantissa_bits(); + // Convolution has a window and dimensions. case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1438,10 +1522,10 @@ string HloInstruction::ToString(bool compact_operands, string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if (ShapeUtil::ElementsIn(shape()) <= 10) { - // LiteralUtil::ToString emits multidimensional arrays over multiple + if (!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) { + // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. 
- string tmp = LiteralUtil::ToString(literal()); + string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = tensorflow::str_util::Split(tmp, ' '); bool first = true; @@ -1455,7 +1539,7 @@ string HloInstruction::ToString(bool compact_operands, first = false; } } else { - // Do not show large constants. + // Do not show large constants or tuples. operands = "{...}"; } } else if (opcode() == HloOpcode::kParameter) { @@ -1565,7 +1649,7 @@ HloInstructionProto HloInstruction::ToProto() const { case HloOpcode::kFusion: { HloComputationProto* proto_fused_computation = proto.mutable_fused_instructions_computation(); - proto_fused_computation->set_name(FullyQualifiedName()); + proto_fused_computation->set_name(name()); // Fill in fused instructions. Note that fused_instructions() returns in // reverse post-order (i.e. root first), so we reverse to get post-order. @@ -1629,6 +1713,8 @@ string HloInstruction::ToCategory() const { case FusionKind::kConvBackwardFilter: case FusionKind::kConvBackwardInput: return "convolution fusion"; + case FusionKind::kCustom: + return "custom fusion"; } } @@ -1639,14 +1725,6 @@ string HloInstruction::ToCategory() const { return HloOpcodeString(opcode()); } -string HloInstruction::FullyQualifiedName() const { - if (IsFused()) { - return StrCat(fusion_instruction()->parent()->name(), - "::", fusion_instruction()->name(), "::", name_); - } - return StrCat(parent_->name(), "::", name_); -} - HloInstruction* HloInstruction::tracing() const { return trace_instruction_; } void HloInstruction::set_tracing(HloInstruction* trace_instruction) { @@ -1689,7 +1767,10 @@ bool HloInstruction::IsFusable() const { HloComputation* HloInstruction::fused_instructions_computation() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation_.get(); + CHECK(!called_computations_.empty()); + auto* fused_instructions_computation = called_computations_.front(); + 
CHECK(fused_instructions_computation->IsFusionComputation()); + return fused_instructions_computation; } HloInstruction* HloInstruction::fusion_instruction() const { @@ -1699,32 +1780,24 @@ HloInstruction* HloInstruction::fusion_instruction() const { HloInstruction* HloInstruction::fused_expression_root() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->root_instruction(); + return fused_instructions_computation()->root_instruction(); } HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->parameter_instruction( + return fused_instructions_computation()->parameter_instruction( parameter_number); } const std::vector& HloInstruction::fused_parameters() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->parameter_instructions(); + return fused_instructions_computation()->parameter_instructions(); } const std::list>& HloInstruction::fused_instructions() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->instructions(); + return fused_instructions_computation()->instructions(); } HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) @@ -1736,6 +1809,10 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { switch (opcode_) { case HloOpcode::kAbs: return visitor->HandleAbs(this, operands_[0]); + case HloOpcode::kBatchNormTraining: + return visitor->HandleBatchNormTraining(this); + case 
HloOpcode::kBatchNormGrad: + return visitor->HandleBatchNormGrad(this); case HloOpcode::kSign: return visitor->HandleSign(this, operands_[0]); case HloOpcode::kConstant: @@ -1758,9 +1835,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kSubtract: return visitor->HandleSubtract(this, operands_[0], operands_[1]); case HloOpcode::kMaximum: - return visitor->HandleMaximum(this, operands_[0], operands_[1]); + return visitor->HandleMaximum(this); case HloOpcode::kMinimum: - return visitor->HandleMinimum(this, operands_[0], operands_[1]); + return visitor->HandleMinimum(this); case HloOpcode::kLogicalAnd: return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); case HloOpcode::kLogicalOr: @@ -1768,9 +1845,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: - return visitor->HandleConvert(this, operands_[0]); + return visitor->HandleConvert(this); case HloOpcode::kCopy: - return visitor->HandleCopy(this, operands_[0]); + return visitor->HandleCopy(this); case HloOpcode::kMultiply: return visitor->HandleMultiply(this, operands_[0], operands_[1]); case HloOpcode::kDot: @@ -1814,6 +1891,10 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleLog(this, operands_[0]); case HloOpcode::kTanh: return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kCos: + return visitor->HandleCos(this, operands_[0]); + case HloOpcode::kSin: + return visitor->HandleSin(this, operands_[0]); case HloOpcode::kIsFinite: return visitor->HandleIsFinite(this, operands_[0]); case HloOpcode::kLogicalNot: @@ -1830,6 +1911,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleTranspose(this); case HloOpcode::kReverse: return visitor->HandleReverse(this, operands_[0]); + case HloOpcode::kReducePrecision: + return visitor->HandleReducePrecision(this); case HloOpcode::kSlice: return 
visitor->HandleSlice(this, operands_[0]); case HloOpcode::kDynamicSlice: @@ -1868,72 +1951,90 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { HloOpcodeString(opcode_).c_str()); } -Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors) { - // Do not visit this HLO node again if it is already visited. - if (visitor->DidVisit(*this)) { - VLOG(3) << "Not visiting HLO " << name() << " as it was already visited."; - return Status::OK(); +static Status PushDFSChild(DfsHloVisitor* visitor, + std::vector* dfs_stack, + HloInstruction* parent, HloInstruction* child) { + switch (visitor->GetVisitState(*child)) { + case DfsHloVisitor::kVisiting: + return FailedPrecondition( + "A cycle is detected while visiting instruction %s", + parent->ToString().c_str()); + + case DfsHloVisitor::kVisited: + VLOG(3) << "Not visiting HLO " << child->name() + << " as it was already visited."; + return Status::OK(); + + case DfsHloVisitor::kNotVisited: + dfs_stack->push_back(child); + return Status::OK(); } +} - // If the instruction is in the visiting state, it means a cycle. - if (visitor->IsVisiting(*this)) { - return FailedPrecondition( - "A cycle is detected while visiting instruction %s", - ToString().c_str()); - } - visitor->SetVisiting(*this); - - // Sort operands, if an ordering was provided. 'temp_sorted_operands' must - // live at this scope, since 'operands' will point to it if the operands are - // sorted. The purpose of the 'operands' pointer is to avoid copying the - // operands in the common case where the operands are not sorted. 
- std::vector* operands = &operands_; - std::vector temp_sorted_operands; - if (operand_order != nullptr) { - temp_sorted_operands = operands_; - std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(), - *operand_order); - operands = &temp_sorted_operands; - } - for (HloInstruction* operand : *operands) { - VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO " - << name(); - TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order, - ignore_control_predecessors)); - } - - if (!ignore_control_predecessors) { - // This uses the same pointer/vector sorting to avoid extra copies as above. - std::vector* predecessors = &control_predecessors_; - std::vector temp_sorted_predecessors; - if (operand_order != nullptr) { - temp_sorted_predecessors = control_predecessors_; - std::sort(temp_sorted_predecessors.begin(), - temp_sorted_predecessors.end(), *operand_order); - predecessors = &temp_sorted_predecessors; +static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, + const HloInstruction::CompareFunction* operand_order, + bool ignore_control_predecessors) { + std::vector dfs_stack; + dfs_stack.push_back(root); + + do { + DCHECK(!dfs_stack.empty()); + + HloInstruction* current_node = dfs_stack.back(); + DfsHloVisitor::VisitState visit_state = + visitor->GetVisitState(*current_node); + if (visit_state == DfsHloVisitor::kVisited) { + dfs_stack.pop_back(); + VLOG(3) << "Not visiting HLO " << current_node->name() + << " as it was already visited."; + continue; } - for (HloInstruction* control_predecessor : *predecessors) { - VLOG(3) << "Going to visit HLO " << control_predecessor->name() - << " as a control predecessor of HLO " << name(); - TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal( - visitor, operand_order, ignore_control_predecessors)); + + if (visit_state == DfsHloVisitor::kVisiting) { + dfs_stack.pop_back(); + + TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); + VLOG(2) << "Visiting HLO " << 
current_node->name(); + TF_RETURN_IF_ERROR(current_node->Visit(visitor)); + visitor->SetVisited(*current_node); + TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); + continue; + } + + visitor->SetVisiting(*current_node); + + const size_t old_dfs_stack_size = dfs_stack.size(); + + for (HloInstruction* child : current_node->operands()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } + + if (!ignore_control_predecessors) { + for (HloInstruction* child : current_node->control_predecessors()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } } - } - TF_RETURN_IF_ERROR(visitor->Preprocess(this)); - VLOG(2) << "Visiting HLO " << name(); - TF_RETURN_IF_ERROR(Visit(visitor)); - visitor->SetVisited(*this); - return visitor->Postprocess(this); + if (operand_order != nullptr) { + std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(), + *operand_order); + } + + // This makes the traversal order the same as what you'd expect + // out of a recursive algorithm. 
+ std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end()); + } while (!dfs_stack.empty()); + + return Status::OK(); } Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, bool ignore_control_predecessors) { - VLOG(2) << "HloInstruction::Accept(" << name() << ")"; + VLOG(3) << "HloInstruction::Accept(" << name() << ")"; TF_RETURN_IF_ERROR( - AcceptInternal(visitor, nullptr, ignore_control_predecessors)); + PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -1944,11 +2045,14 @@ Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; - TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order, - /*ignore_control_predecessors=*/false)); + TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order, + /*ignore_control_predecessors=*/false)); if (call_finish_visit) { + VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT"; TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); + VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT"; } + VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT"; return Status::OK(); } @@ -2060,13 +2164,16 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: + case HloOpcode::kReducePrecision: case HloOpcode::kSign: + case HloOpcode::kSin: case HloOpcode::kTanh: return true; @@ -2274,6 +2381,8 @@ string ToString(HloInstruction::FusionKind kind) { return "kConvBackwardFilter"; case HloInstruction::FusionKind::kConvBackwardInput: return "kConvBackwardInput"; + case 
HloInstruction::FusionKind::kCustom: + return "kCustom"; } } @@ -2345,7 +2454,13 @@ HloModule* HloInstruction::GetModule() const { } void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { + string parent_str = parent() == nullptr ? "noparent" : parent()->name(); name_ = name_uniquer->GetUniqueName(name_); } +void HloInstruction::set_outer_dimension_partitions( + const std::vector& outer_dimension_partitions) { + outer_dimension_partitions_ = outer_dimension_partitions; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index c7cd729934b2a52d95b32b4ba5f5c84dc087cfd4..3c188ec83f3bdcec7c40835794d1694f883388a0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -63,6 +63,9 @@ class HloInstruction { kTransposeDot, // Fused into a dot with transposed operands. kConvBackwardFilter, // Fused into a backward filter convolution. kConvBackwardInput, // Fused into a backward input convolution. + + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -131,6 +134,13 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a reduce-precision op, where operand is the data to reduce in + // precision, and exponent_bits and mantissa_bits describe the precision to + // reduce it to. + static std::unique_ptr CreateReducePrecision( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits); + // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( const Shape& shape, HloInstruction* operand); @@ -209,6 +219,17 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation); + // Creates a batch-norm-training instruction. 
+ static std::unique_ptr CreateBatchNormTraining( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index); + + // Creates a batch-norm-grad instruction. + static std::unique_ptr CreateBatchNormGrad( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + // Creates a scatter computation that scatters the `source` array to the // selected indices of each window. static std::unique_ptr CreateSelectAndScatter( @@ -510,11 +531,6 @@ class HloInstruction { // or "elementwise". string ToCategory() const; - // Returns the string concatenation of parent name and this instructions - // name. This name is guaranteed to be unique among all instructions in the - // HloModule. - string FullyQualifiedName() const; - // Returns a logging instruction, if the output of this instruction is logged. // // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace @@ -528,6 +544,18 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv int64 channel_id() const { return channel_id_; } + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + // + // Precondition: opcode() == HloOpcode::kBatchNormTraining + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + // + // Precondition: opcode() == HloOpcode::kBatchNormTraining + float epsilon() const { return epsilon_; } + // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. 
@@ -642,7 +670,7 @@ class HloInstruction { // Returns the stride in the given dimension for a slice node. // // Precondition: opcode() == HloOpcode::kSlice - int64 slice_stride(int64 dimension) const { + int64 slice_strides(int64 dimension) const { CHECK_EQ(HloOpcode::kSlice, opcode_); return slice_strides_[dimension]; } @@ -661,6 +689,22 @@ class HloInstruction { return dynamic_slice_sizes_; } + // Returns the number of exponent bits for a reduce-precision node. + // + // Precondition: opcode() == HloOpcode::kReducePrecision + int32 exponent_bits() const { + CHECK_EQ(HloOpcode::kReducePrecision, opcode_); + return exponent_bits_; + } + + // Returns the number of mantissa bits for a reduce-precision node. + // + // Precondition: opcode() == HloOpcode::kReducePrecision + int32 mantissa_bits() const { + CHECK_EQ(HloOpcode::kReducePrecision, opcode_); + return mantissa_bits_; + } + // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -708,6 +752,16 @@ class HloInstruction { return called_computations_; } + // Replaces all called computations based on a map function. This is needed + // when we clone hlo_computations and want to let the instructions to point + // to the newly cloned nodes. + void ReplaceCalledComputations( + std::function map_function) { + for (int64 i = 0; i < called_computations_.size(); ++i) { + called_computations_[i] = map_function(called_computations_[i]); + } + } + // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, // after performing necessary implicit broadcast @@ -742,9 +796,9 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns the opcode string for this instruction. Compared with - // HloOpcodeString method, this wrapper dumps additional information - // such as fusion kind. 
+ // Returns the opcode string for this instruction. This is the result from + // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a + // ':'. string ExtendedOpcodeStr() const; // Returns a string identifier for this instruction. If no string identifier @@ -782,6 +836,17 @@ class HloInstruction { parent_fusion_instruction_ = fusion_instruction; } + // Get/Set the number of partitions per outer dimension (in order, starting + // with outer-most dimension first). Currently used by the parallel cpu + // backend to partition HLOs into parallel tasks. + // TODO(b/62783254) Replace these methods with a more general way to + // annotate HLOs with backend-specific information. + const std::vector& outer_dimension_partitions() const { + return outer_dimension_partitions_; + } + void set_outer_dimension_partitions( + const std::vector& outer_dimension_partitions); + private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; @@ -818,12 +883,6 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands); - // Inner DFS traversal function -- this function being called (rather than - // Accept above) allows us to distinguish the root of the traversal. - Status AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors); - // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; @@ -864,6 +923,10 @@ class HloInstruction { std::vector slice_limits_; std::vector slice_strides_; + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_; + int32 mantissa_bits_; + // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; @@ -872,10 +935,6 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. 
std::unique_ptr padding_config_; - // The computation that stores of instructions fused into this fusion - // instruction. Only set for fusion instructions. - std::unique_ptr fused_instructions_computation_; - // If this instruction is fused into a fusion instruction, this field points // to the fusion instruction. HloInstruction* parent_fusion_instruction_ = nullptr; @@ -934,6 +993,14 @@ class HloInstruction { // Only present for kRng. RandomDistribution distribution_; + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining. + float epsilon_; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining. + int64 feature_index_; + // Represents a unique identifier for each Send/Recv instruction pair. // Only present for kSend or kRecv. int64 channel_id_ = -1; @@ -950,6 +1017,10 @@ class HloInstruction { // Metadata for debugging. OpMetadata metadata_; + // The number of partitions per outer dimension (listed in order from + // outer-most dimension first). 
+ std::vector outer_dimension_partitions_; + TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index bcf81cd8ddf63eff2f1df9c6c797588eee42f6b5..ced8417fcef9c009f8f3706ef4707bf0835faec2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -232,7 +232,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { // ------- auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0.get(), c0.get()); auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, @@ -271,7 +271,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { // ------- auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto neg1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0.get()); auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0.get(), neg1.get()); @@ -307,7 +307,7 @@ TEST_F(HloInstructionTest, TrivialMap) { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); auto add_f32 = builder.Build(); @@ -349,9 +349,8 @@ TEST_F(HloInstructionTest, 
TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. auto param0 = HloInstruction::CreateParameter(0, f32a100x10, ""); - auto const0 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto const0 = HloInstruction::CreateConstant(Literal::CreateR0(0.0f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto reduce = HloInstruction::CreateReduce(f32v100, param0.get(), const0.get(), /*dimensions_to_reduce=*/{1}, add_f32.get()); @@ -558,78 +557,110 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { } TEST_F(HloInstructionTest, SingletonFusionOp) { + HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. - auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto exp = - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); - - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, exp.get()); - - EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); - EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get())); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(fusion->operands(), ElementsAre(constant)); + EXPECT_THAT(constant->users(), ElementsAre(fusion)); } TEST_F(HloInstructionTest, BinaryFusionOp) { + HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single binary operation. 
- auto constant1 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto constant2 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f)); - auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, - constant1.get(), constant2.get()); - - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, add.get()); - - EXPECT_THAT(fusion->operands(), - ElementsAre(constant1.get(), constant2.get())); - EXPECT_THAT(constant1->users(), - UnorderedElementsAre(fusion.get(), add.get())); - EXPECT_THAT(constant2->users(), - UnorderedElementsAre(fusion.get(), add.get())); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {add}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2)); + EXPECT_THAT(constant1->users(), ElementsAre(fusion)); + EXPECT_THAT(constant2->users(), ElementsAre(fusion)); } TEST_F(HloInstructionTest, ChainFusionOp) { + HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. 
- auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto exp1 = - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); - auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); - auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get()); - - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, exp3.get()); - fusion->FuseInstruction(exp2.get()); - fusion->FuseInstruction(exp1.get()); - - EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); - EXPECT_THAT(constant->users(), - UnorderedElementsAre(fusion.get(), exp1.get())); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1)); + auto exp3 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(fusion->operands(), ElementsAre(constant)); + EXPECT_THAT(constant->users(), ElementsAre(fusion)); } TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { + HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. 
- auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto exp1 = - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); - auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1)); OpMetadata metadata; metadata.set_op_name("tf_op"); exp1->set_metadata(metadata); exp2->set_metadata(metadata); - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, exp2.get()); - auto* fused = fusion->FuseInstruction(exp1.get()); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp2, exp1}, HloInstruction::FusionKind::kLoop); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); - EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + metadata, fusion->fused_expression_root()->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + metadata, fusion->fused_expression_root()->operand(0)->metadata())); +} + +TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { + HloComputation::Builder builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({ + {1, 2}, + {3, 4}, + }))); + auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + auto outfeed10 = builder.AddInstruction( + HloInstruction::CreateOutfeed(shape10, constant, "")); + auto outfeed01 = builder.AddInstruction( + HloInstruction::CreateOutfeed(shape01, constant, "")); + + auto clone01 = 
builder.AddInstruction(outfeed01->Clone()); + auto clone10 = builder.AddInstruction(outfeed10->Clone()); + + EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01)); + EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10)); } TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { + HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -643,33 +674,36 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { std::unique_ptr computation_x = make_map_computation(); std::unique_ptr computation_y = make_map_computation(); - auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto map_1_x = - HloInstruction::CreateMap(scalar_shape, {constant.get()}, - computation_x.get(), /*static_operands=*/{}); - auto map_2_x = - HloInstruction::CreateMap(scalar_shape, {map_1_x.get()}, - computation_x.get(), /*static_operands=*/{}); - auto map_3_y = - HloInstruction::CreateMap(scalar_shape, {map_2_x.get()}, - computation_y.get(), /*static_operands=*/{}); - - auto fusion = HloInstruction::CreateFusion( - scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get()); - - EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get())); - - fusion->FuseInstruction(map_2_x.get()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( + scalar_shape, {constant}, computation_x.get(), /*static_operands=*/{})); + auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( + scalar_shape, {map_1_x}, computation_x.get(), /*static_operands=*/{})); + auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( + scalar_shape, {map_2_x}, computation_y.get(), /*static_operands=*/{})); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + 
auto* fusion = computation->CreateFusionInstruction( + {map_3_y}, HloInstruction::FusionKind::kLoop); + auto* fused_computation = fusion->fused_instructions_computation(); EXPECT_THAT(fusion->called_computations(), - ElementsAre(computation_y.get(), computation_x.get())); + ElementsAre(fused_computation, computation_y.get())); - fusion->FuseInstruction(map_1_x.get()); - EXPECT_THAT(fusion->called_computations(), - ElementsAre(computation_y.get(), computation_x.get())); + fusion->FuseInstruction(map_2_x); + EXPECT_THAT( + fusion->called_computations(), + ElementsAre(fused_computation, computation_y.get(), computation_x.get())); + + fusion->FuseInstruction(map_1_x); + EXPECT_THAT( + fusion->called_computations(), + ElementsAre(fused_computation, computation_y.get(), computation_x.get())); } TEST_F(HloInstructionTest, ComplexFusionOp) { + HloComputation::Builder builder(TestName()); // Fuse all instructions in complicated expression: // // add = Add(C1, C2) @@ -681,35 +715,35 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // // Notable complexities are repeated operands in a same instruction, different // shapes, use of value in different expressions. 
- auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f)); - auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f)); - - auto add = - HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get()); - auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, - c2.get(), add.get(), add.get()); - auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get()); - auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, - exp.get(), c3.get()); - auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, - mul.get(), clamp.get()); + auto c1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto c2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.1f))); + auto c3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(9.0f))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2)); + auto clamp = builder.AddInstruction( + HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp)); auto tuple = - HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()}); + builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, tuple.get()); - fusion->FuseInstruction(sub.get()); - fusion->FuseInstruction(mul.get()); - fusion->FuseInstruction(exp.get()); - fusion->FuseInstruction(clamp.get()); - fusion->FuseInstruction(add.get()); + HloModule 
module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); // Operands in the fusion instruction's operands() vector should be in the // order in which their users were added fused. - EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get())); - EXPECT_THAT(c1->users(), - UnorderedElementsAre(add.get(), tuple.get(), fusion.get())); + EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2)); + EXPECT_THAT(c1->users(), ElementsAre(fusion)); } // Convenience function for comparing two HloInstructions inside of @@ -732,11 +766,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Create a set of random constant operands to use below. Make them matrices // so dimensions are interesting. auto operand1 = HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); auto operand2 = HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); - auto vector_operand = HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0, 123.0})); + Literal::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = + HloInstruction::CreateConstant(Literal::CreateR1({42.0, 123.0})); Shape shape = operand1->shape(); // Convenient short names for the operands. 
@@ -865,7 +899,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - auto computation = builder.Build(); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -907,7 +942,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - auto computation = builder.Build(); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -946,7 +982,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); - auto computation = builder.Build(); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 141251011cc0b4205b6069ff90415492ead9f7a9..79f17bbb6bd9bfc0c6ed48c68599ef51fbd27af8 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -95,6 +95,7 @@ HLO_MATCHER(Parameter); HLO_MATCHER(Power); HLO_MATCHER(Recv); HLO_MATCHER(Reduce); +HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReduceWindow); HLO_MATCHER(Remainder); HLO_MATCHER(Reshape); diff --git 
a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 22ef9c590bcf63a4e0c60931f771455601b0c019..da6f1d77ecb82ddbce11ca43c184ce0552b757fa 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -37,19 +37,17 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(name), config_(config), - entry_computation_(nullptr), has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - computation_name_uniquer_(/*separator=*/".") {} + entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) - : name_(name), - entry_computation_(nullptr), - computation_name_uniquer_(/*separator=*/".") {} +HloModule::HloModule(const string& name) : name_(name) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation) { computation->UniquifyName(&computation_name_uniquer_); + for (auto& instruction : computation->instructions()) { + instruction->UniquifyName(&instruction_name_uniquer_); + } computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -301,6 +299,36 @@ std::list HloModule::MakeComputationPostOrder() const { return post_order; } +std::unique_ptr HloModule::Clone(const string& suffix) { + VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; + auto module = MakeUnique(name_ + "-" + suffix); + module->config_ = config_; + module->entry_computation_handle_ = entry_computation_handle_; + module->has_entry_computation_handle_ = has_entry_computation_handle_; + + std::unordered_map clone_map; + for (auto& computation : computations_) { + auto cloned_computation = computation->Clone(suffix); + InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); + + if (entry_computation_ == computation.get()) { + module->AddEntryComputation(std::move(cloned_computation)); + } else { + 
module->AddEmbeddedComputation(std::move(cloned_computation)); + } + } + + for (auto& cloned_computation : module->computations_) { + for (auto& instruction : cloned_computation->instructions()) { + // Rewrite instruction's called_computation to point to the cloned + // computations. + instruction->ReplaceCalledComputations( + [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); }); + } + } + return module; +} + uint64 HloModule::RandomNew64() const { tensorflow::mutex_lock l(rng_mutex_); return rng_(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 4b14b4fd62a460ede0639e4417507ff2af02abd6..ae8ec02fbd1a59fa1f4a4a6160de6db0c033c4b1 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -75,6 +75,9 @@ class HloModule { const string& name() const { return name_; } + // Returns a deep copy of this module including all computations. + std::unique_ptr Clone(const string& suffix = "clone"); + // Return a pointer to the entry computation of the module.. HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); @@ -121,13 +124,16 @@ class HloModule { return computation_name_uniquer_.GetUniqueName(prefix); } + // Returns the NameUniquer for uniquing instruction names in this module. + NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation); const string name_; HloModuleConfig config_; - HloComputation* entry_computation_; + HloComputation* entry_computation_ = nullptr; std::vector> computations_; // Random number generator engine to use when generating random numbers per @@ -141,8 +147,10 @@ class HloModule { bool has_entry_computation_handle_ = false; VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation names, which are unique per module. 
- NameUniquer computation_name_uniquer_; + // Unique name generator for computation and instruction names, which are + // unique per module. + NameUniquer computation_name_uniquer_{/*separator=*/"."}; + NameUniquer instruction_name_uniquer_{/*separator=*/"."}; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index a2235a268235860a633fdc5f26c5127574a9487c..8974deb530c2e4561b5ab57f43c65fd525db3617 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -58,6 +58,10 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::replica_count=", replica_count()); } StrAppend(&key, debug_options_.DebugString()); + if (intra_op_parallelism_threads() > 0) { + StrAppend(&key, "::intra_op_parallelism_threads=", + intra_op_parallelism_threads()); + } return key; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index ee32ab9bc4b5dd406d0dd9b6dfff52f852883dd9..2299200b5be969c065fded840709a3d6034efe47 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -92,6 +92,15 @@ class HloModuleConfig { debug_options_ = debug_options; } + // Sets/returns the number of intra op threads for this module. + void set_intra_op_parallelism_threads( + const int intra_op_parallelism_threads) { + intra_op_parallelism_threads_ = intra_op_parallelism_threads; + } + int64 intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -116,6 +125,10 @@ class HloModuleConfig { // The number of replicas to compile this binary for. int64 replica_count_ = 1; + // The target maximum parallelism at which to partition HLOs for parallel + // execution on the CPU backend. 
+ int64 intra_op_parallelism_threads_ = -1; + DebugOptions debug_options_; }; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 870bc729aec98a2959de5aa322850898502394ad..56dc5632035c625445018becfd25d69557e6232a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase { std::unique_ptr CreateConstantComputation() { auto builder = HloComputation::Builder("Constant"); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); return builder.Build(); } @@ -81,6 +81,30 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { EXPECT_EQ(computation2->name(), "Constant.1"); } +TEST_F(HloModuleTest, CloneTest) { + // Create and copy a module with a diamond call graph of computations. + auto module = CreateNewModule(); + auto computation1 = + module->AddEmbeddedComputation(CreateConstantComputation()); + auto computation2 = + module->AddEmbeddedComputation(CreateCallComputation({computation1})); + auto computation3 = + module->AddEmbeddedComputation(CreateCallComputation({computation1})); + module->AddEntryComputation( + CreateCallComputation({computation2, computation3})); + + auto post_order = module->MakeComputationPostOrder(); + auto cloned_module = module->Clone("copy"); + auto post_order_copied = cloned_module->MakeComputationPostOrder(); + + EXPECT_EQ(post_order.size(), post_order_copied.size()); + for (auto origin = post_order.begin(), copied = post_order_copied.begin(); + origin != post_order.end() && copied != post_order_copied.end(); + ++origin, ++copied) { + EXPECT_EQ((*origin)->name() + "copy", (*copied)->name()); + } +} + TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. 
auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index ceb0cdaa3169bb57e4ebb61ac1b2ea41f1ef7995..3888f757adaf2e51c598a08f7464688d162595a4 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -19,11 +19,22 @@ limitations under the License. namespace xla { string HloOpcodeString(HloOpcode opcode) { + // Note: Do not use ':' in opcode strings. It is used as a special character + // in these places: + // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to + // separate the opcode from the fusion kind + // - In fully qualified names (HloInstruction::FullyQualifiedName()), to + // separate the qualifiers (name of the computation and potentially the + // fusion instruction) from the name switch (opcode) { case HloOpcode::kAbs: return "abs"; case HloOpcode::kAdd: return "add"; + case HloOpcode::kBatchNormTraining: + return "batch-norm-training"; + case HloOpcode::kBatchNormGrad: + return "batch-norm-grad"; case HloOpcode::kBitcast: return "bitcast"; case HloOpcode::kBroadcast: @@ -40,6 +51,8 @@ string HloOpcodeString(HloOpcode opcode) { return "convert"; case HloOpcode::kConvolution: return "convolution"; + case HloOpcode::kCos: + return "cosine"; case HloOpcode::kCrossReplicaSum: return "cross-replica-sum"; case HloOpcode::kCustomCall: @@ -112,6 +125,8 @@ string HloOpcodeString(HloOpcode opcode) { return "recv"; case HloOpcode::kReduce: return "reduce"; + case HloOpcode::kReducePrecision: + return "reduce-precision"; case HloOpcode::kReduceWindow: return "reduce-window"; case HloOpcode::kRemainder: @@ -130,6 +145,8 @@ string HloOpcodeString(HloOpcode opcode) { return "send"; case HloOpcode::kSign: return "sign"; + case HloOpcode::kSin: + return "sine"; case HloOpcode::kSlice: return "slice"; case HloOpcode::kSort: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h 
b/tensorflow/compiler/xla/service/hlo_opcode.h index e2cdbfdfa7a4b5509dccf9a83ffbd799f9ab1374..8a6376b2d1c3d4fcdb4cbcb40cd56c1f9db9ec8e 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -30,6 +30,8 @@ namespace xla { enum class HloOpcode { kAbs, kAdd, + kBatchNormTraining, + kBatchNormGrad, kBitcast, kBroadcast, kCall, @@ -40,6 +42,7 @@ enum class HloOpcode { kConvert, kConvolution, kCopy, + kCos, kCrossReplicaSum, kCustomCall, kDivide, @@ -74,6 +77,7 @@ enum class HloOpcode { kPower, kRecv, kReduce, + kReducePrecision, kReduceWindow, kRemainder, kReshape, @@ -83,6 +87,7 @@ enum class HloOpcode { kSelectAndScatter, kSend, kSign, + kSin, kSlice, kSort, kSubtract, @@ -107,6 +112,11 @@ bool HloOpcodeIsComparison(HloOpcode opcode); // Returns true iff the given opcode has variadic operands. bool HloOpcodeIsVariadic(HloOpcode opcode); +// Returns the number of HloOpcode values. +inline const uint32_t HloOpcodeCount() { + return static_cast(HloOpcode::kWhile) + 1; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 72911ae9f91c175d729c3136959cf47029e8a695..4c3ff3bdafc0e5184b715b938b317c3ff85fbfa8 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -15,13 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include #include #include -#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -113,6 +110,20 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, // a_ancestor and b_ancestor must be either both null or both non-null. CHECK_NE(b_ancestor, nullptr); CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); + + // If the common ancestor is a while instruction there is an additional + // ordering criteria which may apply. The condition computation is considered + // to execute before the body computation so if 'a' is in the condition and + // 'b' is in the body, then 'a' executes before 'b'. + if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) { + const HloComputation* body = a_ancestor->while_body(); + const HloComputation* condition = a_ancestor->while_condition(); + if (call_graph_->InstructionIsNestedIn(a, condition) && + call_graph_->InstructionIsNestedIn(b, body)) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } @@ -141,7 +152,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( CHECK_EQ(a->parent(), b->parent()); // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. 
- return strict_predecessors_.at(b->parent())->IsReachable(b, a); + return a != b && predecessors_.at(a->parent())->IsReachable(a, b); } string PredecessorHloOrdering::ToStringHelper(const string& name) const { @@ -153,10 +164,10 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { pieces.push_back(tensorflow::strings::Printf( - " %s strict predecessors:", instruction->name().c_str())); + " %s predecessors:", instruction->name().c_str())); for (auto predecessor : all) { - if (strict_predecessors_.at(computation.get()) - ->IsReachable(instruction, predecessor)) { + if (predecessors_.at(computation.get()) + ->IsReachable(predecessor, instruction)) { pieces.push_back( tensorflow::strings::Printf(" %s", predecessor->name().c_str())); } @@ -172,8 +183,11 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. 
for (auto& computation : module->computations()) { - strict_predecessors_.emplace(computation.get(), - computation->ComputeTransitiveOperands()); + if (computation->IsFusionComputation()) { + continue; + } + predecessors_.emplace(computation.get(), + computation->ComputeReachability()); } } @@ -238,358 +252,6 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - -namespace { - -// Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. -class ListScheduler { - public: - // Construct and return a memory-minimizing sequence of HLO instructions - // containing the given HLO computation. 
- static StatusOr> Run( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); - return scheduler.CreateSchedule(); - } - - private: - // The scheduling priority of an instruction is first the number of bytes - // freed by scheduling the instruction, and second (tie-breaker) by the number - // of users. This is represented as a std::pair containing these two values - // (first element is the bytes freed). std::pair provides the necessary - // comparison operators. - using Priority = std::pair; - - ListScheduler(const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) - : computation_(computation), - points_to_analysis_(points_to_analysis), - size_function_(size_function) { - // Create a map containing the LogicalBuffer uses for each HLO - // instruction. An HLO instruction "uses" a LogicalBuffer if the - // LogicalBuffer is in an operand of the instruction as indicated by - // points-to analysis. - for (auto& instruction : computation.instructions()) { - buffer_uses_.insert( - {instruction.get(), std::unordered_set()}); - for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - buffer_uses_[instruction.get()].insert(buffer); - } - } - } - - // Create map containing the number of unscheduled uses (hlo instructions) - // of each logical buffer. 
- for (auto& instruction : computation.instructions()) { - for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( - instruction.get())) { - unscheduled_use_count_[buffer] = 0; - } - } - for (auto& instruction : computation.instructions()) { - for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { - ++unscheduled_use_count_[buffer]; - } - } - - // Buffers live out of the computation have an implicit use at the end of - // the computation. - for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) - .CreateFlattenedSet()) { - ++unscheduled_use_count_[live_out_buffer]; - } - } - - // Returns whether the memory used by the given buffer should be ignored by - // the scheduling heuristic. - bool IgnoreBuffer(const LogicalBuffer& buffer) { - return buffer.instruction()->opcode() == HloOpcode::kParameter || - buffer.instruction()->opcode() == HloOpcode::kConstant; - } - - // Return the number of bytes freed if the HLO instruction is scheduled. - int64 BytesFreedIfScheduled(const HloInstruction* instruction) { - int64 freed_bytes = 0; - // Sum the total size of the values last used by this instruction. - for (auto* buffer : buffer_uses_.at(instruction)) { - if (IgnoreBuffer(*buffer)) { - continue; - } - CHECK_GE(unscheduled_use_count_.at(buffer), 1); - if (unscheduled_use_count_.at(buffer) == 1) { - // This is the last use of the logical buffer. - freed_bytes += size_function_(*buffer); - } - } - // Then subtract the size of the value(s) defined by this instruction. - for (auto* buffer : - points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { - if (!IgnoreBuffer(*buffer)) { - freed_bytes -= size_function_(*buffer); - } - } - return freed_bytes; - } - - // Construct the scheduling priority of the given instruction. 
- Priority GetPriority(const HloInstruction* instruction) { - return {BytesFreedIfScheduled(instruction), instruction->user_count()}; - } - - std::vector CreateSchedule() { - std::vector schedule; - - // Populate the ready list with instructions which have no operands or - // control predecessors. - std::unordered_map unscheduled_pred_count; - std::list ready_list; - for (auto& instruction : computation_.instructions()) { - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { - unscheduled_pred_count[user]++; - } - for (const HloInstruction* succ : instruction->control_successors()) { - unscheduled_pred_count[succ]++; - } - } - for (auto& instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction.get()) == 0) { - ready_list.push_back(instruction.get()); - } - } - - while (!ready_list.empty()) { - // Select the highest priority HLO instruction from the ready list. - auto best_it = ready_list.begin(); - Priority best_priority = GetPriority(*best_it); - for (auto ready_it = std::next(ready_list.begin()); - ready_it != ready_list.end(); ++ready_it) { - Priority priority = GetPriority(*ready_it); - if (priority > best_priority) { - best_it = ready_it; - best_priority = priority; - } - } - - // Remove the selected instruction from the ready list and add it to the - // schedule. - const HloInstruction* best = *best_it; - ready_list.erase(best_it); - schedule.push_back(best); - scheduled_instructions_.insert(best); - - // Update the unscheduled uses of the logical buffers. - for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - CHECK_GT(unscheduled_use_count_.at(buffer), 0); - --unscheduled_use_count_[buffer]; - } - - // Add new instructions to ready list. 
- auto update_pred_count = [&unscheduled_pred_count, - &ready_list](HloInstruction* inst) { - int64 pred_count = --unscheduled_pred_count.at(inst); - CHECK_GE(pred_count, 0); - if (pred_count == 0) { - ready_list.push_back(inst); - } - }; - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (HloInstruction* user : best->users()) { - update_pred_count(user); - } - for (HloInstruction* succ : best->control_successors()) { - update_pred_count(succ); - } - } - CHECK_EQ(schedule.size(), computation_.instructions().size()); - CHECK_EQ(scheduled_instructions_.size(), - computation_.instructions().size()); - - return schedule; - } - - const HloComputation& computation_; - const TuplePointsToAnalysis& points_to_analysis_; - const LogicalBuffer::SizeFunction& size_function_; - - // A map containing the LogicalBuffers that each instruction uses. - std::unordered_map> - buffer_uses_; - - // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. - std::unordered_map unscheduled_use_count_; - - // Set of instructions which have been scheduled. - std::unordered_set scheduled_instructions_; -}; - -int64 SumLogicalBufferSizes(const std::vector& buffers, - const LogicalBuffer::SizeFunction& size_function) { - int64 size = 0; - for (const LogicalBuffer* buffer : buffers) { - size += size_function(*buffer); - } - return size; -} - -StatusOr> RunDFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. 
- tensorflow::gtl::FlatMap extra_users; - tensorflow::gtl::FlatMap total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { - extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; - total_sizes[hlo] = SumLogicalBufferSizes( - points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); - tensorflow::gtl::FlatSet unique_operands( - hlo->operands().begin(), hlo->operands().end()); - for (const HloInstruction* operand : unique_operands) { - extra_users[hlo] += extra_users[operand]; - total_sizes[hlo] += total_sizes[operand]; - } - } - CHECK_EQ(extra_users.size(), computation.instructions().size()); - CHECK_EQ(total_sizes.size(), computation.instructions().size()); - - // Construct a total order based on DFS post-order, visiting operands in - // decreasing cumulative extra user order, and next by cumulative size, with a - // tiebreaker by name for determinism. - std::vector sequence; - FunctionVisitor visitor([&sequence](HloInstruction* hlo) { - sequence.push_back(hlo); - return Status::OK(); - }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( - &visitor, [&extra_users, &total_sizes](const HloInstruction* a, - const HloInstruction* b) { - if (extra_users[a] != extra_users[b]) { - return extra_users[a] > extra_users[b]; - } - if (total_sizes[a] != total_sizes[b]) { - return total_sizes[a] > total_sizes[b]; - } - return a->name() < b->name(); - })); - CHECK_EQ(sequence.size(), computation.instructions().size()); - return sequence; -} - -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& 
computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. - TF_ASSIGN_OR_RETURN( - std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; - - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; - - if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; - return list_sequence; - } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; - return dfs_sequence; - } -} - -} // namespace - -StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); - for (const auto& computation : module.computations()) { - TF_ASSIGN_OR_RETURN(sequence[computation.get()], - CreateMemoryMinimizingSequence( - *computation, *points_to_analysis, size_function)); - } - return sequence; -} - -StatusOr> 
CreateMemoryMinimizingSequence( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function); -} - std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b59e1ea5eb0ad4882d4c2b96ee6ab6d1bc973993..130431f28070d52c3a76befa0d5272a3cc295711 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -24,12 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -72,8 +68,8 @@ class HloOrdering { std::unique_ptr call_graph_; }; -// Base class for partial orderings implemented by a map of strict predecessors -// for each instruction. Subclasses should fill in strict_predecessors_. +// Base class for partial orderings implemented by a map of predecessors for +// each instruction. Subclasses should fill in predecessors_. class PredecessorHloOrdering : public HloOrdering { public: ~PredecessorHloOrdering() override = default; @@ -93,13 +89,12 @@ class PredecessorHloOrdering : public HloOrdering { const HloInstruction* b) const override; // For each computation in the module, this is the set of the instruction's - // strict predecessors. 
An instruction is not an element of its own strict - // predecessor set. + // predecessors. An instruction is an element of its own predecessor set. // // Subclasses should fill this in to define the desired ordering. tensorflow::gtl::FlatMap> - strict_predecessors_; + std::unique_ptr> + predecessors_; }; // An HLO ordering based on data dependencies in the HLO graph. In this partial @@ -191,24 +186,6 @@ std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence); -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns an HloModuleSequence which seeks to minimize the memory required for -// the computation. size_function is the function returning the number of bytes -// required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function); - -// Overload of above that computes the sequence for a single computation. -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 21d852a51d67b2aadc0edea144f60a037a004614..ad6070a9c1b45afd418c9210a2d1b3def3eaf4d5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -61,7 +62,7 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); @@ -101,7 +102,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { auto builder_c = HloComputation::Builder("C"); HloInstruction* c = builder_c.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); HloComputation* computation_c = module->AddEmbeddedComputation(builder_c.Build()); @@ -155,67 +156,69 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { EXPECT_FALSE(ordering.ExecutesBefore(y, c)); } -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { +TEST_F(HloOrderingTest, InstructionsInWhileComputations) { + // Tests the ordering of instructions in the body and condition of a while + // instruction. 
HLO code: + // + // body(F32[]) %param): + // %negate = Negate(%param) + // + // condition(F32[] %param): + // %convert = Convert(%param) + // + // entry: + // %constant = Constant(1.0) + // return While(%constant, body, condition) + // auto module = CreateNewModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "body_param")); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape, HloOpcode::kNegate, body_param)); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto 
cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "cond_param")); + auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(xla::PRED, {}), cond_param)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); + module->AddEntryComputation(builder.Build()); + + DependencyHloOrdering 
ordering(module.get()); + EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, convert)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, negate)); + + // The while should be unordered relative to the body and condition + // instructions. + EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param)); + EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param)); + EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while)); + EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while)); + + // Condition instructions should be ordered before body instructions. + EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate)); + EXPECT_TRUE(ordering.ExecutesBefore(convert, negate)); + + EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } } // namespace - } // namespace xla int main(int argc, char** argv) { diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 119e2d79022dca094147348d83c59b9a04cb339f..4b824f8240074e7ae70b9d9fa82dfa0706d5b355 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -30,9 +31,10 @@ using ::tensorflow::strings::StrAppend; namespace xla { namespace { -void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, +void DumpModule(const HloModule& module, + const string& message) { - dumper_(module, message); + hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(2) << "HLO " << message << ":"; XLA_VLOG_LINES(2, module.ToString()); } @@ -75,7 +77,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); - DumpModule(dumper_, *module, message); + DumpModule(*module, message); TF_RETURN_IF_ERROR(run_invariant_checkers()); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); @@ -85,7 +87,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { StrAppend(&prefix, name(), ": after ", pass->name()); } TF_RETURN_IF_ERROR(run_invariant_checkers()); - DumpModule(dumper_, *module, prefix + ", pipeline end"); + DumpModule(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 682c4b952df6aae8cb933c222772dbd823070ecc..a42d7e59fed2d838dfe3cb7f99e6b946edfdb0b4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -22,7 +22,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,9 +33,7 @@ namespace xla { // Pipeline of HLO passes. class HloPassPipeline : public HloPassInterface { public: - explicit HloPassPipeline(const string& name, - const Compiler::HloDumper& dumper) - : name_(name), dumper_(dumper) {} + explicit HloPassPipeline(const string& name) : name_(name) {} tensorflow::StringPiece name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the @@ -69,7 +66,6 @@ class HloPassPipeline : public HloPassInterface { private: const string name_; - Compiler::HloDumper dumper_; std::vector> passes_; std::vector> invariant_checkers_; bool run_called_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index a153d73dbd838663c0d7e0d72ad54668f243f2c2..d45038f1f4a2e4aa19234eec93fdc9a068a902e1 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -25,7 +25,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && ShapeUtil::IsScalarF32(instruction->shape())) { - *out = LiteralUtil::Get(instruction->literal(), {}); + *out = instruction->literal().Get({}); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb7ecbdc2a09e6e797d283675ccf2c26f9c1a34c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_reachability.h" + +namespace xla { + +HloReachabilityMap::HloReachabilityMap( + const std::list& instructions) + : size_(instructions.size()) { + bit_vectors_.reserve(size_); + for (const HloInstruction* hlo : instructions) { + indices_[hlo] = bit_vectors_.size(); + bit_vectors_.emplace_back(size_); + } + CHECK_EQ(size_, indices_.size()); // instructions should be unique +} + +bool HloReachabilityMap::SetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + BitVector& bit_vector = GetBitVector(instruction); + tmp_bit_vector_ = bit_vector; + + bit_vector.SetToZero(); + bit_vector.Set(GetIndex(instruction)); + for (const HloInstruction* input : inputs) { + bit_vector.OrWith(GetBitVector(input)); + } + + return bit_vector != tmp_bit_vector_; +} + +void HloReachabilityMap::SetReachable(const HloInstruction* a, + const HloInstruction* b) { + GetBitVector(b).Set(GetIndex(a)); +} + +bool HloReachabilityMap::IsReachable(const HloInstruction* a, + const HloInstruction* b) const { + return GetBitVector(b).Get(GetIndex(a)); +} + +bool HloReachabilityMap::IsConnected(const HloInstruction* a, + const HloInstruction* b) const { + return IsReachable(a, b) || IsReachable(b, a); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h 
b/tensorflow/compiler/xla/service/hlo_reachability.h new file mode 100644 index 0000000000000000000000000000000000000000..d7bdac9c86579f19afbba133772c2c50894853d1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ + +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloInstruction; + +// A class for computing and representing reachability between HloInstructions. +class HloReachabilityMap { + public: + // Sets up an empty reachable matrix for the full set of instructions + // specified in 'instructions'. + explicit HloReachabilityMap(const std::list& instructions); + + // Set the reachability set of 'instruction' to the union of the reachability + // sets of 'inputs'. Upon return, IsReachable(x, instruction) where + // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true + // for some 'input' in 'inputs'. 
Also sets 'instruction' to be reachable from + // itself. Returns whether the reachability set of 'instruction' changed. + bool SetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + + // Sets entry so that IsReachable(a, b) will return true + void SetReachable(const HloInstruction* a, const HloInstruction* b); + + // Returns true if "b" is reachable from "a" + bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; + + // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + + private: + // A bit-vector implementation specialized for this use case which provides a + // fast bitwise OR operation not available in tensorflow::gtl::BitMap. + class BitVector { + public: + BitVector() = default; + BitVector(size_t size) + : size_(size), vector_((size + kBits - 1) / kBits, 0) {} + + // Return the bit at the given index. + bool Get(size_t index) const { + DCHECK(index >= 0 && index < size_); + return vector_[index / kBits] & (1ull << (index % kBits)); + } + + // Set the bit at the given index. + void Set(size_t index) { + DCHECK(index >= 0 && index < size_); + vector_[index / kBits] |= 1ull << (index % kBits); + } + + // Set this bitvector to the Logical OR of this bitvector and 'other'. + void OrWith(const BitVector& other) { + for (size_t i = 0; i < vector_.size(); ++i) { + vector_[i] |= other.vector_[i]; + } + } + + // Set the bitvector to all zeros. + void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } + + bool operator==(const BitVector& other) const { + return vector_ == other.vector_; + } + bool operator!=(const BitVector& other) const { + return vector_ != other.vector_; + } + + private: + using Word = uint64; + static const size_t kBits = 64; + + // Number of bits in the bitvector. 
+ size_t size_; + + std::vector vector_; + }; + + // Return the bitvector storing the reachability-to of the given instruction. + const BitVector& GetBitVector(const HloInstruction* instruction) const { + return bit_vectors_[GetIndex(instruction)]; + } + BitVector& GetBitVector(const HloInstruction* instruction) { + return bit_vectors_[GetIndex(instruction)]; + } + + // Return the index of the given instruction. The value is used to index into + // the vector of BitVectors and the BitVectors themselves. + int GetIndex(const HloInstruction* instruction) const { + return FindOrDie(indices_, instruction); + } + + // The number of instructions in the reachability map. + const size_t size_; + + // Dense assignment from HloInstruction* to number. These numbers index + // into the bit_vectors_ vector and into the bits within a BitVector. + tensorflow::gtl::FlatMap indices_; + + // Bitvectors holding the reachability to each instruction. The bit vector for + // instruction X includes ones for each instruction which X is reachable from. + std::vector bit_vectors_; + + // A temporary used by SetReachabilityToUnion to avoid an allocation with each + // call to the method. + BitVector tmp_bit_vector_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..657a9ee83d29e72b95660325f9139f44159d6508 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_reachability.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +namespace { + +class HloReachabilityTest : public HloTestBase {}; + +TEST_F(HloReachabilityTest, Reachability) { + // Construct and test a reachability graph of the following form: + /* + a + / \ + b c + \ / \ + d e + */ + auto builder = HloComputation::Builder(TestName()); + auto a = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto b = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto c = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto d = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto e = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.Build(); + + HloReachabilityMap reachability({a, b, c, d, e}); + reachability.SetReachable(a, a); + EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, b)); + EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, c)); + EXPECT_TRUE(reachability.SetReachabilityToUnion({b, c}, d)); + EXPECT_TRUE(reachability.SetReachabilityToUnion({c}, e)); + + EXPECT_TRUE(reachability.IsReachable(a, a)); + EXPECT_TRUE(reachability.IsReachable(a, b)); + 
EXPECT_TRUE(reachability.IsReachable(a, c)); + EXPECT_TRUE(reachability.IsReachable(a, d)); + EXPECT_TRUE(reachability.IsReachable(a, e)); + + EXPECT_FALSE(reachability.IsReachable(b, a)); + EXPECT_TRUE(reachability.IsReachable(b, b)); + EXPECT_FALSE(reachability.IsReachable(b, c)); + EXPECT_TRUE(reachability.IsReachable(b, d)); + EXPECT_FALSE(reachability.IsReachable(b, e)); + + EXPECT_FALSE(reachability.IsReachable(e, a)); + EXPECT_FALSE(reachability.IsReachable(e, b)); + EXPECT_FALSE(reachability.IsReachable(e, c)); + EXPECT_FALSE(reachability.IsReachable(e, d)); + EXPECT_TRUE(reachability.IsReachable(e, e)); + + // Recomputing the same reachability for a previously computed instruction + // should return false (no change). + EXPECT_FALSE(reachability.SetReachabilityToUnion({a}, b)); + EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 2c1b0fff4e602a172cfa54d4eaa626198a426873..fd08796e50383ab9ad1aff4f19e8c67fd72a9a63 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -58,9 +59,8 @@ bool IsRematerializable(const HloInstruction* instruction) { return false; } - // Don't rematerialize instructions with side effects, those with a cost that - // might not be captured by HloCostAnalysis, or instructions which cannot be - // cloned safely. 
+ // Don't rematerialize instructions with side effects or instructions which + // cannot be cloned safely. switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: @@ -802,23 +802,14 @@ bool MemoryUsageTracker::Check() const { // Computes and returns the cost of rematerializing the given instruction. // Cost per rematerialized instruction is defined as: // -// (flop_count + transcendental_count + element_count) / memory_reduced +// memory_limit_bytes / memory_reduced // -// flop_count: from HloCostAnalysis -// transcendental_count: from HloCostAnalysis -// element_count: number of elements accessed in operands and output of -// instruction -// memory_reduced: The memory usage reduced by rematerializing the -// instruction. -// -// This is a rough estimate of the extra execution time per byte saved by -// rematerializing this instruction for its remaining uses. In general, we -// want the most memory saving for the least latency penalty which is captured -// by this heuristic. +// The idea is to choose the operation that will save the most memory for +// rematerialization and do not worry about how much the compute costs since +// running out of memory is more harmful than taking longer to get the answer. int64 RematerializationCost(const HloInstruction* instruction, const MemoryUsageTracker& memory_tracker, - const HloCostAnalysis& cost_analysis, - int64 memory_reduced) { + int64 memory_reduced, int64 memory_limit_bytes) { // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. @@ -830,22 +821,8 @@ int64 RematerializationCost(const HloInstruction* instruction, } CHECK_GT(memory_reduced, 0); - const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); - const int64 elements_accessed = - ShapeUtil::IsTuple(instruction->shape()) - ? 
bytes_accessed - : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( - instruction->shape().element_type()); - - // Multiply by 256 to improve precision of cost. Without this factor, - // many instructions such as many elementwise instructions would have - // zero cost because the bytes reduced can be several times greater than - // the element count. - return 256 * - (cost_analysis.flop_count(*instruction) + - cost_analysis.transcendental_count(*instruction) + - elements_accessed) / - memory_reduced; + // Return the inverse of the benefit of rematerialization. + return memory_limit_bytes / memory_reduced; } // Selects and returns the best candidate instruction for rematerialization. @@ -856,8 +833,8 @@ int64 RematerializationCost(const HloInstruction* instruction, HloInstruction* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, - const HloCostAnalysis& cost_analysis, - const tensorflow::gtl::FlatSet& blacklist) { + const tensorflow::gtl::FlatSet& blacklist, + int64 memory_limit_bytes) { HloInstruction* best = nullptr; int64 best_cost = 0; @@ -891,12 +868,12 @@ HloInstruction* PickRematerializationCandidate( if (memory_reduced <= 0) { VLOG(5) << "candidate " << candidate->name() - << " memory reduced = " << memory_reduced << " <= 0"; + << " memory reduced = " << memory_reduced << " <= 0"; continue; } const int cost = RematerializationCost(candidate, memory_tracker, - cost_analysis, memory_reduced); + memory_reduced, memory_limit_bytes); VLOG(5) << "candidate " << candidate->name() << ", memory reduced " << memory_reduced << ", cost per byte " << cost; @@ -1011,7 +988,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, cost_analysis_, blacklist); + memory_tracker, instruction_list, blacklist, memory_limit_bytes); if (best == nullptr) { 
VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1211,11 +1188,6 @@ StatusOr HloRematerialization::Run( VLOG(1) << "Peak memory usage of module (before): " << HumanReadableNumBytes(before_peak_memory); - // Run cost analysis. Operation cost is used in the heuristic for selecting - // instructions for rematerialization. - TF_RETURN_IF_ERROR( - module->entry_computation()->root_instruction()->Accept(&cost_analysis_)); - // Subcomputations called by the entry computation will also be // rematerialized. TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( @@ -1230,6 +1202,9 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. for (const auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } if (sequence->at(computation.get()).size() != computation->instruction_count()) { // A size mismatch between the computation instruction count and the size diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 1693f93183bc59c343e3c765cb4051566d4377ef..42c279d440b78d90b9f19b92155c52787156e4b7 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -18,7 +18,6 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -61,7 +60,7 @@ class HloRematerialization { protected: HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function), cost_analysis_(size_function_) {} + : 
size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -100,9 +99,6 @@ class HloRematerialization { // Call graph of the hlo_module. std::unique_ptr call_graph_; - // Analysis used for computing the rematerialization cost of instructions. - HloCostAnalysis cost_analysis_; - // The peak memory usage of each computation. The map contains only those // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f306bcc309c6c5e57a311496ee0370741de8a6ab..2358969f38ee66e3eb024215cba4c62da3d6a32f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -126,7 +126,7 @@ class HloRematerializationTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); return builder.Build(); } @@ -158,7 +158,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). 
- TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -191,7 +191,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -215,7 +215,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -232,7 +232,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. 
SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -254,7 +254,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -268,7 +268,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(body_computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -289,7 +289,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -310,7 +310,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. 
SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -357,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice( vec1024_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1024}, /*slices=*/{1})); + /*limit_indices=*/{1024}, /*strides=*/{1})); subcomputation = module->AddEmbeddedComputation(builder.Build()); } @@ -406,7 +406,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -473,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice( vec1024_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1024}, /*slices=*/{1})); + /*limit_indices=*/{1024}, /*strides=*/{1})); subcomputation = module->AddEmbeddedComputation(builder.Build()); } @@ -503,7 +503,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). 
- TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc new file mode 100644 index 0000000000000000000000000000000000000000..922236ee1e79c65719f128c598a5de65d7fc1ab7 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -0,0 +1,423 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +namespace { + +// Class implementing a list scheduler of HLO instructions which produces a +// sequence which minimizes memory usage. 
+class ListScheduler { + public: + // Construct and return a memory-minimizing sequence of HLO instructions + // containing the given HLO computation. + static StatusOr> Run( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + ListScheduler scheduler(computation, points_to_analysis, size_function); + return scheduler.CreateSchedule(); + } + + private: + // The scheduling priority of an instruction is first the number of bytes + // freed by scheduling the instruction, and second (tie-breaker) by the number + // of users. This is represented as a std::pair containing these two values + // (first element is the bytes freed). std::pair provides the necessary + // comparison operators. + using Priority = std::pair; + + ListScheduler(const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) + : computation_(computation), + points_to_analysis_(points_to_analysis), + size_function_(size_function) { + // Create a map containing the LogicalBuffer uses for each HLO + // instruction. An HLO instruction "uses" a LogicalBuffer if the + // LogicalBuffer is in an operand of the instruction as indicated by + // points-to analysis. + for (auto& instruction : computation.instructions()) { + std::unordered_set instr_uses; + for (auto* operand : instruction->operands()) { + for (const LogicalBuffer* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(operand)) { + instr_uses.insert(buffer); + } + } + buffer_uses_[instruction.get()] = std::vector( + instr_uses.begin(), instr_uses.end()); + } + + // Create map containing the number of unscheduled uses (hlo instructions) + // of each logical buffer. 
+ for (auto& instruction : computation.instructions()) { + for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( + instruction.get())) { + unscheduled_use_count_[buffer] = 0; + } + } + for (auto& instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { + ++unscheduled_use_count_[buffer]; + } + } + + // Buffers live out of the computation have an implicit use at the end of + // the computation. + for (const LogicalBuffer* live_out_buffer : + points_to_analysis.GetPointsToSet(computation.root_instruction()) + .CreateFlattenedSet()) { + ++unscheduled_use_count_[live_out_buffer]; + } + } + + // Returns whether the memory used by the given buffer should be ignored by + // the scheduling heuristic. + bool IgnoreBuffer(const LogicalBuffer& buffer) { + return buffer.instruction()->opcode() == HloOpcode::kParameter || + buffer.instruction()->opcode() == HloOpcode::kConstant; + } + + // An entry in the worklist used by CreateSchedule. Corresponds to one + // HloInstruction, plus some cached metadata, saved for the purposes of making + // BytesFreedIfScheduled fast. + struct ReadyListEntry { + const HloInstruction* instruction; + + // The total size of all buffers defined by this instruction. + int64 bytes_defined; + + // For each buffer B used by this instruction, we keep a pair (B, U), where + // U is the number of uses of B that have not yet been scheduled. This pair + // is a pointer into the unscheduled_use_count_ map, so it gets updated for + // free when we update counts in the map. + std::vector*> + used_buffer_unscheduled_use_counts; + }; + + // Creates a ReadyListEntry for the given instruction. 
+ ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { + ReadyListEntry entry; + entry.instruction = instruction; + + entry.bytes_defined = 0; + for (auto* buffer : + points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { + if (!IgnoreBuffer(*buffer)) { + entry.bytes_defined += size_function_(*buffer); + } + } + + for (auto* buffer : buffer_uses_.at(instruction)) { + if (IgnoreBuffer(*buffer)) { + continue; + } + auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); + CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); + entry.used_buffer_unscheduled_use_counts.push_back( + &*unscheduled_use_count_it); + } + return entry; + } + + // Returns the number of bytes freed if the HLO instruction is scheduled. + int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { + int64 freed_bytes = 0; + for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { + auto buffer = kv->first; + auto use_count = kv->second; + if (use_count == 1) { + freed_bytes += size_function_(*buffer); + } + } + return freed_bytes - entry.bytes_defined; + } + + // Constructs the scheduling priority of the given instruction. + Priority GetPriority(const ReadyListEntry& entry) { + return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; + } + + std::vector CreateSchedule() { + std::vector schedule; + + // Populate the ready list with instructions which have no operands or + // control predecessors. + std::unordered_map unscheduled_pred_count; + for (auto& instruction : computation_.instructions()) { + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. 
+ for (const HloInstruction* user : instruction->users()) { + unscheduled_pred_count[user]++; + } + for (const HloInstruction* succ : instruction->control_successors()) { + unscheduled_pred_count[succ]++; + } + } + + std::list ready_list; + for (auto& instruction : computation_.instructions()) { + // Instruction with no operands or control predecessors will + // not be in the map. + if (unscheduled_pred_count.count(instruction.get()) == 0) { + ready_list.push_back(MakeReadyListEntry(instruction.get())); + } + } + + while (!ready_list.empty()) { + // Select the highest priority HLO instruction from the ready list. + auto best_it = ready_list.begin(); + Priority best_priority = GetPriority(*best_it); + for (auto ready_it = std::next(ready_list.begin()); + ready_it != ready_list.end(); ++ready_it) { + Priority priority = GetPriority(*ready_it); + if (priority > best_priority) { + best_it = ready_it; + best_priority = priority; + } + } + + // Remove the selected instruction from the ready list and add it to the + // schedule. + const HloInstruction* best = best_it->instruction; + ready_list.erase(best_it); + schedule.push_back(best); + scheduled_instructions_.insert(best); + + // Update the unscheduled uses of the logical buffers. + for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { + CHECK_GT(unscheduled_use_count_.at(buffer), 0); + --unscheduled_use_count_[buffer]; + } + + // Add new instructions to ready list. + auto update_pred_count = [&](HloInstruction* inst) { + int64 pred_count = --unscheduled_pred_count.at(inst); + CHECK_GE(pred_count, 0); + if (pred_count == 0) { + ready_list.push_back(MakeReadyListEntry(inst)); + } + }; + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. 
+ for (HloInstruction* user : best->users()) { + update_pred_count(user); + } + for (HloInstruction* succ : best->control_successors()) { + update_pred_count(succ); + } + } + CHECK_EQ(schedule.size(), computation_.instructions().size()); + CHECK_EQ(scheduled_instructions_.size(), + computation_.instructions().size()); + + return schedule; + } + + const HloComputation& computation_; + const TuplePointsToAnalysis& points_to_analysis_; + const LogicalBuffer::SizeFunction& size_function_; + + // A map containing the LogicalBuffers that each instruction uses. + std::unordered_map> + buffer_uses_; + + // A map containing the count of unscheduled HLOs which use a particular + // LogicalBuffer. We rely on pointer stability of entries in this map. + std::unordered_map unscheduled_use_count_; + + // Set of instructions which have been scheduled. + std::unordered_set scheduled_instructions_; +}; + +int64 SumLogicalBufferSizes(const std::vector& buffers, + const LogicalBuffer::SizeFunction& size_function) { + int64 size = 0; + for (const LogicalBuffer* buffer : buffers) { + size += size_function(*buffer); + } + return size; +} + +StatusOr> RunDFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // This ordering is based on DFS post-order, with a heuristic to decide which + // operand to visit first. The heuristic is based on 'extra_users', which is + // simply users-1 for each instruction. By subtracting 1, we're saying that + // instructions with no users or a single user don't count; instructions with + // lots of fan-out will be visited earlier. + tensorflow::gtl::FlatMap extra_users; + tensorflow::gtl::FlatMap total_sizes; + for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + extra_users[hlo] = hlo->users().empty() ? 
0 : hlo->users().size() - 1; + total_sizes[hlo] = SumLogicalBufferSizes( + points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + tensorflow::gtl::FlatSet unique_operands( + hlo->operands().begin(), hlo->operands().end()); + for (const HloInstruction* operand : unique_operands) { + extra_users[hlo] += extra_users[operand]; + total_sizes[hlo] += total_sizes[operand]; + } + } + CHECK_EQ(extra_users.size(), computation.instructions().size()); + CHECK_EQ(total_sizes.size(), computation.instructions().size()); + + // Construct a total order based on DFS post-order, visiting operands in + // decreasing cumulative extra user order, and next by cumulative size, with a + // tiebreaker by name for determinism. + std::vector sequence; + FunctionVisitor visitor([&sequence](HloInstruction* hlo) { + sequence.push_back(hlo); + return Status::OK(); + }); + TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + &visitor, [&extra_users, &total_sizes](const HloInstruction* a, + const HloInstruction* b) { + if (extra_users[a] != extra_users[b]) { + return extra_users[a] > extra_users[b]; + } + if (total_sizes[a] != total_sizes[b]) { + return total_sizes[a] > total_sizes[b]; + } + return a->name() < b->name(); + })); + CHECK_EQ(sequence.size(), computation.instructions().size()); + return sequence; +} + +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // We try both a list-scheduler based ordering and a DFS based ordering, and + // choose 
whichever returns a lower min-memory, not accounting for + // fragmentation. + // + // Note that this is just a heuristic. One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. + TF_ASSIGN_OR_RETURN( + std::vector list_sequence, + ListScheduler::Run(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 list_memory, + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + + TF_ASSIGN_OR_RETURN( + std::vector dfs_sequence, + RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 dfs_memory, + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); + VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + + if (list_memory <= dfs_memory) { + VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + return list_sequence; + } else { + VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + return dfs_sequence; + } +} + +} // namespace + +StatusOr +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { + SequentialHloOrdering::HloModuleSequence sequence; + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(&module)); + for (const auto& computation : module.computations()) { + if (computation->IsFusionComputation()) { + continue; + } + TF_ASSIGN_OR_RETURN(sequence[computation.get()], + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function)); + } + return sequence; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function) { + CHECK(!computation.IsFusionComputation()); + 
TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(computation.parent())); + return CreateMemoryMinimizingSequence(computation, *points_to_analysis, + size_function); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h new file mode 100644 index 0000000000000000000000000000000000000000..ec92a56b962152b15981f868369683144aa7c76a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -0,0 +1,50 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Returns the minimum memory required to compute the given module sequence, +// assuming no fragmentation. 
+StatusOr MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + +// Returns an HloModuleSequence which seeks to minimize the memory required for +// the computation. size_function is the function returning the number of bytes +// required for a LogicalBuffer. +StatusOr +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + +// Overload of above that computes the sequence for a single computation. +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d09d22ee40638c5beed3f4eaf3723be0f6b6bf96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + 
HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 867ebc7f61aab1483622d1560d951c053e95f135..e3d287d4c91708577b712261842b6ae231fb188b 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant}, callee1)); auto y = builder.AddInstruction( @@ -89,12 +89,14 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", false, false, nullptr); + "before unification", + module->config().debug_options()); } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", false, false, nullptr); + "after unification", + module->config().debug_options()); } EXPECT_EQ(2, module->computations().size()); EXPECT_EQ(x->to_apply(), y->to_apply()); @@ -110,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + HloInstruction::CreateConstant(Literal::CreateR0(3))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); auto y = builder.AddInstruction( @@ -126,12 +128,14 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", false, false, 
nullptr); + "before unification", + module->config().debug_options()); } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", false, false, nullptr); + "after unification", + module->config().debug_options()); } EXPECT_EQ(2, module->computations().size()); EXPECT_EQ(x->to_apply(), y->to_apply()); @@ -164,12 +168,14 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", false, false, nullptr); + "before unification", + module->config().debug_options()); } EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", false, false, nullptr); + "after unification", + module->config().debug_options()); } EXPECT_EQ(3, module->computations().size()); EXPECT_NE(x->to_apply(), y->to_apply()); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 6707b02c5c57262b0154ae6b23fdd61a198a8d70..76177462aa4959261483045296d2388acabe46a5 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -171,8 +171,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, break; case HloOpcode::kConstant: if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s( - LiteralUtil::GetAsString(instruction->literal(), {})); + attrs["value"].set_s(instruction->literal().GetAsString({})); } break; case HloOpcode::kCustomCall: diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 
c2718ea8003c9d2a8e3d65773b439aae915a30d0..8e9d93e367e51cb69f0a38ae7aa8d9539e78ad8a 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -91,7 +91,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); + HloInstruction::CreateConstant(Literal::CreateR0(123))); OpMetadata metadata; metadata.set_op_name("x"); metadata.set_op_type("y"); diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc new file mode 100644 index 0000000000000000000000000000000000000000..221f67b0c1cd280d88c408f69deab12ed51a8b93 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -0,0 +1,327 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_value.h" + +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +const Shape& HloPosition::shape() const { + return ShapeUtil::GetSubshape(instruction->shape(), index); +} + +string HloPosition::ToString() const { + string index_str = + ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; + return StrCat(instruction->name(), index_str); +} + +std::ostream& operator<<(std::ostream& out, const HloPosition& position) { + out << position.ToString(); + return out; +} + +string HloUse::ToString() const { + string index_str = + ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) + ? 
(" " + operand_index.ToString()) + : ""; + return StrCat(instruction->name(), ", operand ", operand_number, index_str); +} + +std::ostream& operator<<(std::ostream& out, const HloUse& use) { + out << use.ToString(); + return out; +} + +HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, + const ShapeIndex& index, bool is_phi) + : id_(id), is_phi_(is_phi) { + // The defining position is always the first element in the positions_ vector. + AddPosition(instruction, index); +} + +bool HloValue::operator==(const HloValue& other) const { + bool equal = defining_instruction() == other.defining_instruction() && + defining_index() == other.defining_index(); + // If the values are equal they most both be phi (or non phi). + CHECK(!(equal && is_phi() != other.is_phi())); + return equal; +} + +bool HloValue::operator!=(const HloValue& other) const { + return !(*this == other); +} + +string HloValue::ToShortString() const { + string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) + ? defining_index().ToString() + : ""; + return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(), + index_str); +} + +string HloValue::ToString(int indent) const { + string indentation(indent, ' '); + string out = StrCat(indentation, ToShortString(), ", positions:\n"); + for (const HloPosition& position : positions()) { + StrAppend(&out, indentation, " ", position.ToString(), "\n"); + } + StrAppend(&out, indentation, " uses:\n"); + for (const HloUse& use : uses()) { + StrAppend(&out, indentation, " ", use.ToString(), "\n"); + } + return out; +} + +namespace { + +// Returns true if the instruction 'user' may use the value at the given +// ShapeIndex in the given operand. Generally, instruction which pass through +// values transparently without reading the value are not considered to use the +// value. 
+bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, + const HloInstruction* user) { + switch (user->opcode()) { + case HloOpcode::kGetTupleElement: + case HloOpcode::kCopy: + // These instructions only access the top-level values of their + // operand. Non-top-level (nested) values are passed through + // transparently. + CHECK_EQ(operand_number, 0); + return index.empty(); + case HloOpcode::kSelect: + // Select does not use any nested elements of its selected-from operands + // (operand 1 and 2) + CHECK_GE(operand_number, 0); + CHECK_LE(operand_number, 2); + return operand_number == 0 || index.empty(); + + case HloOpcode::kCall: + case HloOpcode::kTuple: + // These instructions always pass through their operands transparently. + return false; + + case HloOpcode::kWhile: + // Though the while instructions passes through its operands, we return + // true because in SSA form there may be a Phi at the parameter of the + // while which is considered a use of its incoming value because the Phi + // input values are not passed through into the body computation. Because + // this function is used in both SSA and non-SSA forms of the analysis + // conservatively return true. + return true; + + default: + return true; + } +} + +} // namespace + +void HloValue::AddPosition(HloInstruction* instruction, + const ShapeIndex& index) { + HloPosition new_position{instruction, index}; + + // The new position must not already exist in positions_. + for (const HloPosition& position : positions_) { + DCHECK_NE(position, new_position); + } + // The shape of the new position must match existing positions. + if (!positions_.empty()) { + CHECK( + ShapeUtil::Compatible(positions_.front().shape(), new_position.shape())) + << "front: " << positions_.front() << " new: " << new_position; + } + + positions_.push_back(std::move(new_position)); + + // Update uses. 
+ for (HloInstruction* user : instruction->users()) { + for (int64 operand_number : user->OperandIndices(instruction)) { + if (MayUseOperandValue(operand_number, index, user)) { + HloUse new_use{user, operand_number, index}; + + // The new use must not already exist in uses_. + for (const HloUse& use : uses_) { + DCHECK_NE(use, new_use); + } + + uses_.push_back(std::move(new_use)); + } + } + } + + // Update liveout status of this HloValue. + const HloModule& module = *instruction->parent()->parent(); + if (instruction == module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } + + if (instruction == instruction->parent()->root_instruction()) { + live_out_of_computation_ = true; + } +} + +void HloValue::RemovePosition(HloInstruction* instruction, + const ShapeIndex& index) { + // The defining position cannot be removed. + CHECK(!(instruction == defining_instruction() && index == defining_index())); + + int64 size_before = positions_.size(); + positions_.erase( + std::remove_if(positions_.begin(), positions_.end(), + [instruction, &index](const HloPosition& position) { + return position.instruction == instruction && + position.index == index; + }), + positions_.end()); + // Only a single position should have been removed. + CHECK_EQ(positions_.size(), size_before - 1); + + // Update uses which referred to this position. + uses_.erase(std::remove_if(uses_.begin(), uses_.end(), + [instruction, &index](const HloUse& use) { + return use.instruction->operand( + use.operand_number) == instruction && + use.operand_index == index; + }), + uses_.end()); + + // Returns whether this value is contained in the given instruction's output. 
+ auto is_contained_in = [this](const HloInstruction* instruction) { + for (const HloPosition& position : positions()) { + if (position.instruction == instruction) { + return true; + } + } + return false; + }; + + const HloModule& module = *instruction->parent()->parent(); + if (instruction == module.entry_computation()->root_instruction()) { + // Value has been removed from a position in the entry root instruction. + live_out_of_module_ = + is_contained_in(module.entry_computation()->root_instruction()); + } + if (instruction == defining_instruction()->parent()->root_instruction()) { + // Value has been removed from the root of the computation the value has + // been defined in. + live_out_of_computation_ = + is_contained_in(defining_instruction()->parent()->root_instruction()); + } +} + +std::ostream& operator<<(std::ostream& out, const HloValue& value) { + out << value.ToShortString(); + return out; +} + +void HloValueSet::SortAndUniquifyValues() { + std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); + values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), + values_.end()); +} + +string HloValueSet::ToString() const { + return StrCat("HloValueSet: ", + Join(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); +} + +/*static */ +HloValueSet HloValueSet::Union( + tensorflow::gtl::ArraySlice inputs) { + HloValueSet union_set; + for (const HloValueSet* input : inputs) { + for (const HloValue* value : input->values()) { + union_set.values_.push_back(value); + } + } + union_set.SortAndUniquifyValues(); + return union_set; +} + +bool HloValueSet::AddValue(const HloValue* value) { + auto it = std::lower_bound(values_.begin(), values_.end(), value, + HloValue::IdLessThan); + if (it == values_.end() || (*it)->id() != value->id()) { + values_.insert(it, value); + return true; + } + return false; // already exists +} + +std::ostream& operator<<(std::ostream& out, const HloValueSet& 
value_set) { + out << value_set.ToString(); + return out; +} + +InstructionValueSet InstructionValueSet::Union( + tensorflow::gtl::ArraySlice inputs) { + CHECK_GT(inputs.size(), 0); + for (int i = 1; i < inputs.size(); ++i) { + CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); + } + InstructionValueSet union_set(inputs[0]->shape()); + union_set.ForEachMutableElement( + [&inputs](const ShapeIndex& index, HloValueSet* value_set) { + std::vector input_sets; + for (const InstructionValueSet* input : inputs) { + input_sets.push_back(&input->element(index)); + } + *value_set = HloValueSet::Union(input_sets); + }); + return union_set; +} + +std::ostream& operator<<(std::ostream& out, + const InstructionValueSet& instruction_value_set) { + out << instruction_value_set.ToString(); + return out; +} + +string InstructionValueSet::ToString() const { + string out = + StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); + ForEachElement([this, &out](const ShapeIndex& index, + const HloValueSet& value_set) { + StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); + }); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h new file mode 100644 index 0000000000000000000000000000000000000000..a21e34821748e5077ba19c29057d85f7c12088c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -0,0 +1,267 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Abstraction which identifies a specific point in the XLA graph. An +// HloPosition specifies a ShapeIndex within the output of a specific +// instruction. +struct HloPosition { + HloInstruction* instruction; + ShapeIndex index; + + // Returns the shape at this position. + const Shape& shape() const; + + string ToString() const; + + bool operator==(const HloPosition& other) const { + return instruction == other.instruction && index == other.index; + } + bool operator!=(const HloPosition& other) const { return !(*this == other); } +}; + +std::ostream& operator<<(std::ostream& out, const HloPosition& position); + +// Defines a single use of an HLO value. +struct HloUse { + // Instruction at which the value is used. + HloInstruction* instruction; + + // The operand number in which the value is appears. + int64 operand_number; + + // The shape index within the operand in which the value appears. 
+ ShapeIndex operand_index; + + string ToString() const; + + bool operator==(const HloUse& other) const { + return instruction == other.instruction && + operand_number == other.operand_number && + operand_index == other.operand_index; + } + + bool operator!=(const HloUse& other) const { return !(*this == other); } +}; + +std::ostream& operator<<(std::ostream& out, const HloUse& use); + +// Class describing a value used by the dataflow analysis. XLA arrays are +// trivially a single HloValue. Tuples are made up of more than one HloValue: an +// HloValue for the pointer vector, and an HloValue for each child element. +// +// Every HloValue is defined by a particular instruction and most instructions +// define only a single HloValue. Instructions which define a single HloValue +// include array-shaped instructions such as Add but also includes Tuple-shaped +// instructions such as Tuple. The Tuple instruction defines a single HloValue +// which is a vector of pointers to the values containing the Tuple +// instruction's operands. Though the result of the Tuple instruction includes +// multiple values only the top-level HloValue (the vector of pointers) is +// defined by the Tuple instruction. The values containing the tuple elements +// are defined by earlier instructions, usually the operands of the Tuple +// instruction. +// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one HloValue. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. These tuple-shaped instructions do +// not assemble a tuple from existing HloValues like the Tuple instruction does, +// but rather define all the HloValues in the tuple. +class HloValue { + public: + using Id = int64; + + // Predicate comparing HloValues by increasing id, useful for std::sort. 
+ static bool IdLessThan(const HloValue* a, const HloValue* b) { + return a->id() < b->id(); + } + + // Predicate comparing HloValues by equal id, useful for std::unique. + static bool IdEqual(const HloValue* a, const HloValue* b) { + return a->id() == b->id(); + } + + // Construct an HloValue defined by 'instruction' at shape index 'index'. If + // is_phi is true, then this value is a phi value, for example, at the + // parameter of a while body computation. Phi values are only used in the SSA + // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). + HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, + bool is_phi = false); + + // Return a unique identifier for this HloValue. This value is used for stable + // sorting and iteration + Id id() const { return id_; } + + // Returns whether this value is a phi value. + bool is_phi() const { return is_phi_; } + + // Return the position where this value is defined. + const HloPosition& defining_position() const { return positions_[0]; } + + // Return the instruction which defines this HloValue. + HloInstruction* defining_instruction() const { + return defining_position().instruction; + } + + // Return the shape index at which this HloValue is defined in the output of + // its defining instruction. + const ShapeIndex& defining_index() const { return defining_position().index; } + + // Return the shape of this HloValue. + const Shape& shape() const { return defining_position().shape(); } + + // Add or remove a position at which the HloValue appears. The definition + // position can not be removed. The uses of the HloValue are updated. + void AddPosition(HloInstruction* instruction, const ShapeIndex& index); + void RemovePosition(HloInstruction* instruction, const ShapeIndex& index); + + // Return all positions of the HloValue in the module. + const std::vector& positions() const { return positions_; } + + // Return all uses of the HloValue. 
+ const std::vector& uses() const { return uses_; } + + // Get whether this HloValue is live out of the module. + bool live_out_of_module() const { return live_out_of_module_; } + + // Get whether this HloValue is live out of the computation it is defined in. + bool live_out_of_computation() const { return live_out_of_computation_; } + + bool operator==(const HloValue& other) const; + bool operator!=(const HloValue& other) const; + + // Return a single-line string representation of the value. + string ToShortString() const; + + string ToString(int indent = 0) const; + + private: + // Unique identifier for this HloValue. Used for stable sorting and iteration. + const Id id_; + + // Whether this instruction is a phi value. + const bool is_phi_; + + // The set of positions of this HloValue. The first element is always the + // position of the definition. + std::vector positions_; + + // The set of uses of this HloValue. + std::vector uses_; + + // Whether this value is live out of the HLO module. + bool live_out_of_module_ = false; + + // Whether this value is live out of its computation. + bool live_out_of_computation_ = false; +}; + +std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); + +// A class representing the possible set of HloValues at a particular point +// (shape index in the output of an instruction) in the XLA graph. This set +// contains the set of reaching HloValue definitions. For a simple array-shaped +// instruction like Add, the HloValueSet of the top-level of the instruction's +// output trivially contains only the HloValue defined by the instruction. For +// instructions which have non-trivial dataflow such as Tuple or Select, the +// HloValueSets of the instruction's output contains one or more HloValues +// defined by the instruction's operands or defined further up in the XLA graph. 
+class HloValueSet { + public: + HloValueSet() = default; + + explicit HloValueSet(tensorflow::gtl::ArraySlice values) + : values_(values.begin(), values.end()) { + SortAndUniquifyValues(); + } + + // Return the union of the given HloValueSets. + static HloValueSet Union( + tensorflow::gtl::ArraySlice inputs); + + // Return the vector of HloValues in the set. Values in the vector are unique + // and sorted. + const std::vector& values() const { return values_; } + + // Adds the value to the set. Returns true iff the value was added and didn't + // already exist in the set. + bool AddValue(const HloValue* value); + + // Return the unique HLO value in the set. CHECKs if the set does not contain + // exactly one value. + const HloValue& GetUniqueValue() const { + CHECK_EQ(values_.size(), 1); + return *values_[0]; + } + + bool operator==(const HloValueSet& other) const { + if (values_.size() != other.values_.size()) return false; + for (size_t i = 0; i < values_.size(); ++i) { + if (values_[i]->id() != other.values_[i]->id()) { + return false; + } + } + return true; + } + bool operator!=(const HloValueSet& other) const { return !(*this == other); } + + string ToString() const; + + private: + // Sorts value_ and removes duplicates. This should be called after adding any + // elements to values_. + void SortAndUniquifyValues(); + + // HloValues sorted by HloValue::Id. + std::vector values_; +}; + +std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); + +// A class collecting the HloValues which might be contained in the output of +// an HLO instruction. For array-shaped instructions, an InstructionValueSet +// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets +// hold multiple HloValueSets. +class InstructionValueSet : public ShapeTree { + public: + InstructionValueSet(const Shape& shape) : ShapeTree(shape) {} + + // Return the union of the given InstructionValueSets. 
+ static InstructionValueSet Union( + tensorflow::gtl::ArraySlice inputs); + + string ToString() const; +}; + +std::ostream& operator<<(std::ostream& out, + const InstructionValueSet& instruction_value_set); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index de6081e57e7f27a07b314692c6935ecf3e3c54a9..01fba49bc567900418f9e4622351373abe7b2e18 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -14,10 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { StatusOr HloVerifier::Run(HloModule* module) { + tensorflow::gtl::FlatMap instructions; + for (auto& computation : module->computations()) { for (const auto& instruction : computation->instructions()) { TF_RET_CHECK(instruction->parent() == computation.get()); @@ -30,6 +33,16 @@ StatusOr HloVerifier::Run(HloModule* module) { << " computation: " << computation.get(); } } + + auto previous = instructions.find(instruction->name()); + TF_RET_CHECK(previous == instructions.end()) + << "HLO has name that is not unique within module:\n" + << instruction->ToString() + << " in computation: " << computation->name() + << "\nPrevious HLO with same name:\n" + << previous->second->ToString() + << " in computation: " << previous->second->parent()->name(); + instructions[instruction->name()] = instruction.get(); } } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b9a7a297f80cd249fb3dd7a513d785ed3a444d3 --- /dev/null +++ 
b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" +#include "tensorflow/compiler/xla/metric_table_report.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { + +using tensorflow::strings::Appendf; +using tensorflow::strings::HumanReadableElapsedTime; +using tensorflow::strings::HumanReadableNumBytes; +using tensorflow::strings::StrAppend; + +string HumanReadableProfileBuilder::ToString() const { + string s; + + Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n", + computation_name_.c_str(), + HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); + + auto append_op = [&](const OpInfo& op) { + string bytes_per_sec; + string bytes_per_cycle; + if (op.cycles <= 0 || op.bytes_accessed < 0) { + bytes_per_sec = ""; + bytes_per_cycle = ""; + } else { + bytes_per_sec = + HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); + bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + } + + double cycles_percent = 0; + if (total_cycles_ > 0) { + cycles_percent = op.cycles / 
static_cast(total_cycles_) * 100; + } + + double nsecs = op.cycles / clock_rate_ghz_; + Appendf(&s, + "\t%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s " + ":: %12s/s :: %12s/cycle :: %s\n", + op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), + op.flop_count <= 0 + ? "" + : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + }; + + append_op({"[total]", "[total]", /*category=*/"", total_cycles_, -1, -1}); + + // Sort ops in decreasing order of cycles. + std::vector sorted_ops(op_infos_); + std::sort( + sorted_ops.begin(), sorted_ops.end(), + [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); + for (const auto& op : sorted_ops) { + append_op(op); + } + + if (total_cycles_ <= 0) { + StrAppend(&s, "****** 0 total cycles ******\n"); + } else { + MetricTableReport table; + table.SetMetricName("microseconds"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : sorted_ops) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = CyclesToMicroseconds(op.cycles); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); + } + return s; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..1a69cbf8bf3e2f850eb6a284844b2c95678c92f2 --- /dev/null +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HUMAN_READABLE_PROFILE_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HUMAN_READABLE_PROFILE_BUILDER_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// HumanReadableProfileBuilder helps you create a textual profile of a +// computation, suitable for consumption by humans. +class HumanReadableProfileBuilder { + public: + explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, + int64 total_cycles, + double clock_rate_ghz) + : computation_name_(computation_name.ToString()), + total_cycles_(total_cycles), + clock_rate_ghz_(clock_rate_ghz) { + CHECK_GE(clock_rate_ghz, 1e-9); + } + + int64 total_cycles() const { return total_cycles_; } + + // Adds an operation to the profile. If you don't know the number of + // floating-point ops or bytes touched by the op, pass -1 for that param. + void AddOp(tensorflow::StringPiece op_name, + tensorflow::StringPiece short_name, + tensorflow::StringPiece category, int64 cycles, int64 flop_count, + int64 bytes_accessed) { + op_infos_.push_back({op_name.ToString(), short_name.ToString(), + category.ToString(), cycles, flop_count, + bytes_accessed}); + } + + // Gets the human-readable profile. 
+ string ToString() const; + + private: + struct OpInfo { + string name; + string short_name; + string category; + int64 cycles; + int64 flop_count; + int64 bytes_accessed; + }; + + double CyclesToSeconds(int64 cycles) const { + return cycles / clock_rate_ghz_ / 1e9; + } + double CyclesToMicroseconds(int64 cycles) const { + return cycles / clock_rate_ghz_ / 1000.0; + } + + string computation_name_; + int64 total_cycles_; + double clock_rate_ghz_; + std::vector op_infos_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HUMAN_READABLE_PROFILE_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 2887a8a0a097c9dcb3d490f0845547f104aa1bdf..84bfbb30c30d84a6a233a60fb420b43c3fe3454c 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. 
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); + auto expected = Literal::CreateR1({4, 3, 3, 4}); LiteralTestUtil::ExpectEqual(*result, *expected); } @@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) { HloInstruction::CreateParameter(0, r0f32, "x")); (void)param1; const2_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); auto const2_f32 = const2_builder.Build(); auto builder = HloComputation::Builder("MapConstFunction"); auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); @@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); LiteralTestUtil::ExpectEqual(*result, *expected); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 721640cdbd8133f621f65a2505cdf3b84590e740..24af07bd4bf8d5a61a6092c8eadc5151c09921b4 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" namespace xla { - /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -43,6 +42,7 @@ namespace xla { case HloOpcode::kConstant: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: @@ -64,10 +64,12 @@ namespace xla { case HloOpcode::kNegate: case HloOpcode::kOutfeed: case HloOpcode::kPad: + case HloOpcode::kReducePrecision: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSelect: case HloOpcode::kSign: + case HloOpcode::kSin: case HloOpcode::kSlice: case HloOpcode::kSubtract: case HloOpcode::kTranspose: @@ -75,6 +77,8 @@ namespace xla { return false; // Expensive instructions. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: case HloOpcode::kCall: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: @@ -113,12 +117,111 @@ bool FusionWouldDuplicate(const HloInstruction& producer, const HloInstruction& consumer) { return !(producer.users().size() == 1 && consumer.IsUserOf(&producer)); } + +// An "effectively unary" operation is one that has one "large" +// input with the others being negligible in terms of memory usage. +// We use "has a smaller true rank than the output" as a heuristic +// for "negligible" memory usage. 
+bool EffectivelyUnary(HloInstruction* hlo) { + int64 output_rank = 0; + ShapeUtil::ForEachSubshape( + hlo->shape(), + [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); + } + }); + return std::count_if(hlo->operands().begin(), hlo->operands().end(), + [output_rank](HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kBroadcast) { + return false; + } + if (operand->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(operand->shape())) { + return false; + } + return ShapeUtil::TrueRank(operand->shape()) >= + output_rank; + }) <= 1; +} } // namespace +bool InstructionFusion::CanFuseOnAllPaths( + const HloReachabilityMap& reachability_map, HloInstruction* producer, + HloInstruction* consumer, DoNotFuseSet* do_not_fuse) { + auto could_fuse_on_all_paths = [&] { + // First check to see if we have already marked this producer as infeasible + // to fuse into consumer. + if (do_not_fuse->count(producer) > 0) { + return false; + } + // Make sure it is possible for producer and consumer to exist in a fusion + // node. + if (!producer->IsFusable() || !consumer->IsFusable()) { + return false; + } + // We do an upward walk of the graph from consumer towards all paths which + // lead to producer to find any unfusable paths. + for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { + auto* consumer_operand = consumer->mutable_operand(i); + if (consumer_operand == producer) { + // This is the base case: our upward crawl ends but we need to make sure + // that fusion from consumer can happen. + if (!ShouldFuse(consumer, i)) { + return false; + } + } else if (reachability_map.IsReachable(producer, consumer_operand)) { + // The reachability map told us that consumer_operand is a node on the + // path to producer. We need to further investigate from + // consumer_operand. 
+ + // First check if we have already ruled out fusing producer into + // consumer_operand. + if (do_not_fuse->count(consumer_operand) > 0) { + return false; + } + // Make sure it is possible for consumer_operand to exist in a fusion + // node. + if (!consumer_operand->IsFusable()) { + return false; + } + // The producer is reachable from consumer_operand which means we need + // to be able to fuse consumer_operand into consumer in order for + // producer to be fusable into consumer on all paths. + if (!ShouldFuse(consumer, i)) { + return false; + } + // Perform the recursive step: make sure producer can be fused into + // consumer_operand on all paths. + if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand, + do_not_fuse)) { + return false; + } + } + } + return true; + }; + if (could_fuse_on_all_paths()) { + return true; + } + // We couldn't fuse on all paths, record this result. + do_not_fuse->insert(producer); + return false; +} + StatusOr InstructionFusion::Run(HloModule* module) { bool changed = false; + + std::vector computations; for (auto& computation : module->computations()) { - computation_ = computation.get(); + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (auto& computation : computations) { + CHECK(!computation->IsFusionComputation()); + computation_ = computation; // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. To make @@ -131,56 +234,42 @@ StatusOr InstructionFusion::Run(HloModule* module) { std::vector post_order(post_order_list.begin(), post_order_list.end()); - std::set all_consumers_fusable; - // Find which ops can be fused into all of their operands. We would rather - // not fuse an op into only some of its users, as that offers no benefit in - // terms of memory bandwidth, but forces us to keep more live values around. 
- for (auto* hlo : post_order) { - auto user_fusable_into_hlo = [this, &hlo](HloInstruction* consumer) { - if (!consumer->IsFusable()) { - return false; - } - for (int operand_number = 0; - operand_number < consumer->operands().size(); ++operand_number) { - if (consumer->operand(operand_number) == hlo) { - if (!ShouldFuse(consumer, operand_number)) { - return false; - } - } - } - return true; - }; - - // An "effectively unary" operation is one that has one "large" - // input with the others being negligible in terms of memory usage. - // We use "has a smaller true rank than the output" as a heuristic - // for "negligible" memory usage. - auto effectively_unary = [](HloInstruction* hlo) { - if (hlo->operands().size() == 1) { - return true; - } - auto output_rank = ShapeUtil::TrueRank(hlo->shape()); - return std::count_if( - hlo->operands().begin(), hlo->operands().end(), - [output_rank](HloInstruction* operand) { - return ((operand->opcode() != HloOpcode::kBroadcast) && - ShapeUtil::TrueRank(operand->shape()) >= - output_rank); - }) <= 1; - }; - - if (effectively_unary(hlo) || - std::all_of(hlo->users().begin(), hlo->users().end(), - user_fusable_into_hlo)) { - all_consumers_fusable.insert(hlo); - } - } - tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { InsertOrDie(&post_order_index, post_order[i], i); } + DoNotFuseSet do_not_fuse; + auto reachability = computation->ComputeReachability(); + + auto cheap_to_duplicate = [](HloInstruction* producer) { + if (producer->opcode() == HloOpcode::kBroadcast) { + return true; + } + if (producer->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(producer->shape())) { + return true; + } + if (EffectivelyUnary(producer)) { + return true; + } + return false; + }; + + for (HloInstruction* consumer : post_order) { + for (HloInstruction* producer : consumer->operands()) { + if (cheap_to_duplicate(producer)) { + continue; + } + if (CanFuseOnAllPaths(*reachability, producer, 
consumer, + &do_not_fuse)) { + CHECK_EQ(do_not_fuse.count(producer), 0); + } else { + CHECK_GT(do_not_fuse.count(producer), 0); + } + } + } + // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. When we fuse an edge, we create a copy of the producer inside the @@ -263,34 +352,36 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (FusionWouldDuplicate(*operand, *instruction) && - (all_consumers_fusable.count(operand) == 0)) { + if (!operand->IsFusable()) { continue; } - - if (operand->IsFusable() && ShouldFuse(instruction, i)) { - HloInstruction* fusion_instruction = Fuse(operand, instruction); - - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); - changed = true; - - if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting it's - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - - // Remove from computation. - TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); - } - break; + if (!ShouldFuse(instruction, i)) { + continue; + } + if (do_not_fuse.count(operand) > 0) { + continue; + } + HloInstruction* fusion_instruction = Fuse(operand, instruction); + + // Fusing an instruction into a fusion instruction can change the + // operand set of the fusion instruction. 
For simplicity just push the + // instruction to the top of the post_order and reconsider it for + // further fusion in the next iteration of the outer loop. + post_order.push_back(fusion_instruction); + InsertOrDie(&post_order_index, fusion_instruction, + post_order.size() - 1); + changed = true; + + if (operand->user_count() == 0) { + // Operand is now dead. Remove from post order by setting it's + // location to nullptr. + post_order[FindOrDie(post_order_index, operand)] = nullptr; + post_order_index.erase(operand); + + // Remove from computation. + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + break; } } } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index a9f3723f2dfcc1b3b697d34eb9510f5857a443f0..f6f37bb79b9fe1480db61b10b9810347960f9a72 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -72,6 +72,15 @@ class InstructionFusion : public HloPassInterface { private: HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + // The set of producers whose consumers we cannot fuse into. + using DoNotFuseSet = std::unordered_set; + + // Whether or not we can fuse consumer into original_producer on all paths + // from the producer to the consumer where nodes are HLOs and edges are uses. + bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map, + HloInstruction* producer, HloInstruction* consumer, + DoNotFuseSet* do_not_fuse); + // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. 
std::function is_expensive_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index a2e6c2ae00bd65b1d3aeca49f26448d8a07670a8..b3e0007dcc2d43028b49cc48477a0a69153b13c8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -28,7 +28,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndOperandElementReusingConsumerNotFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* broadcast2 = @@ -49,7 +49,7 @@ TEST_F(InstructionFusionTest, NonCostlyProducerAndOperandElementReusingConsumerFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); HloInstruction* broadcast2 = @@ -70,7 +70,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* reshape2 = builder.AddInstruction( @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { 
HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* transpose2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index e9e199226a6db7a0547bda4b069e917f2a41295b..7d41be94ce92f0b23c8ef414ea6f4fd9fba7d1a4 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -382,7 +382,11 @@ Status LayoutAssignment::AddMandatoryConstraints( // instruction. // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. - shape_with_layout = &instruction->shape(); + // TODO(b/62477016): When the infeed does not set padding anymore, the + // call to ShapeWithoutPadding can be removed. + Shape infeed_shape = ShapeUtil::ShapeWithoutPadding(instruction->shape()); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(infeed_shape, instruction.get())); } else if (instruction->opcode() == HloOpcode::kOutfeed) { // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. @@ -607,6 +611,9 @@ Status CheckLayouts( TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } for (auto& instruction : computation->instructions()) { // Verify every instruction has a layout and the layout is valid for the // shape. @@ -729,23 +736,18 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kReshape) { // Prefer the operand layout that makes the reshape an bitcast. 
If any // dimension bound is 1 in the operand shape, there may be several such - // layouts. So if 'output_layout' is a MajorToMinor layout, try if the + // layouts. So if 'output_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy // operations. const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), AsInt64Slice(output_layout.minor_to_major())); - const Shape& operand_shape = operand->shape(); - if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) { - Shape operand_shape_with_layout = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout, - output_shape_with_layout)) { - return MakeUnique(operand_shape_with_layout.layout()); - } + Shape operand_shape = operand->shape(); + *operand_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(operand_shape); + if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { + return MakeUnique(operand_shape.layout()); } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); @@ -759,10 +761,14 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. 
- std::vector perm = - ComposePermutations(instruction->dimensions(), - AsInt64Slice(output_layout.minor_to_major())); - Layout operand_layout = LayoutUtil::MakeLayout(perm); + int64 rank = ShapeUtil::Rank(instruction->shape()); + std::vector new_minor_to_major(rank); + for (int64 i = 0; i < rank; ++i) { + int64 output_dim = output_layout.minor_to_major(i); + int64 operand_dim = instruction->dimensions(output_dim); + new_minor_to_major[i] = operand_dim; + } + Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); return MakeUnique(operand_layout); @@ -789,23 +795,18 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kReshape) { // Prefer the user layout that makes the reshape an bitcast. If any // dimension bound is 1 in the user shape, there may be several such - // layouts. So if 'operand_layout' is a MajorToMinor layout, try if the + // layouts. So if 'operand_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy // operations. 
Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( operand->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), AsInt64Slice(operand_layout.minor_to_major())); - const Shape& output_shape = user->shape(); - if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) { - Shape output_shape_with_layout = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - output_shape.element_type(), - AsInt64Slice(output_shape.dimensions())); - if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout, - operand_shape_with_layout)) { - return MakeUnique(output_shape_with_layout.layout()); - } + Shape output_shape = user->shape(); + *output_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(output_shape); + if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { + return MakeUnique(output_shape.layout()); } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); @@ -818,14 +819,16 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } if (user->opcode() == HloOpcode::kTranspose) { - // Pick the user layout that makes the reshape a bitcast. - // To become a bitcast, the layouts need to satisfy - // collapsing_order * output_layout = input_layout - // so output_layout = inverse(collapsing_order) * input_layout - std::vector perm = - Permute(InversePermutation(user->dimensions()), - AsInt64Slice(operand_layout.minor_to_major())); - Layout user_layout = LayoutUtil::MakeLayout(perm); + // Pick the user layout that makes the transpose a bitcast. 
+ int64 rank = ShapeUtil::Rank(user->shape()); + std::vector new_minor_to_major(rank); + auto inverse_dimensions = InversePermutation(user->dimensions()); + for (int64 i = 0; i < rank; ++i) { + int64 operand_dim = operand_layout.minor_to_major(i); + int64 user_dim = inverse_dimensions[operand_dim]; + new_minor_to_major[i] = user_dim; + } + Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); return MakeUnique(user_layout); } @@ -926,7 +929,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( ShapeUtil::IsArray(buffer->shape())) { TF_RETURN_IF_ERROR(constraints->SetBufferLayout( ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), - *buffer)); + *buffer, /*mandatory=*/true)); } } } @@ -1346,8 +1349,7 @@ StatusOr LayoutAssignment::Run(HloModule* module) { if (VLOG_IS_ON(10)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), "before layout assignment", - /*show_addresses=*/false, - /*show_layouts=*/true); + module->config().debug_options()); } // Assign layouts to computations in an order such that a callee computation @@ -1357,6 +1359,8 @@ StatusOr LayoutAssignment::Run(HloModule* module) { if (computation == module->entry_computation()) { TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_, module->entry_computation())); + } else if (computation->IsFusionComputation()) { + continue; } else { ComputationLayout computation_layout(computation->ComputeProgramShape()); // Setting all embedded computations to the default layout is potentially @@ -1373,8 +1377,7 @@ StatusOr LayoutAssignment::Run(HloModule* module) { if (VLOG_IS_ON(10)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), "after layout assignment", - /*show_addresses=*/false, - /*show_layouts=*/true); + module->config().debug_options()); } // All layouts are reset then reassigned by this pass. 
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index ccfc17da4c412c23945630a52fc21cbac87e0727..256d6aa8aa64e3585cb21b3fb2a11c7416c705f1 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -315,6 +315,7 @@ class LayoutAssignment : public HloPassInterface { ComputationLayout* entry_computation_layout_; + protected: // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6d818cdea0c30701adf83f6265a6d7b554fb91cc..f69c043f32b4e688a543d277164eb91b364b51dc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -230,7 +230,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -264,7 +264,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -552,6 +552,41 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { ElementsAre(1, 0)); } +// Test layout assignment of a transpose into a bitcast based on its operand. +TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape_with_layout = + ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(transpose)); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + AssignLayouts(module.get(), &computation_layout); + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), + transpose->shape(), {2, 3, 0, 1})); +} +// Test layout assignment of a transpose into a bitcast based on its user. 
+TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, constant, {})); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(transpose)); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + AssignLayouts(module.get(), &computation_layout); + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), + transpose->shape(), {2, 3, 0, 1})); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 682bf19807b4b5d4e8a66c6c5e2e01c80a026594..9c80fb3adbc99b2e5cd3efc20deaf602c5ebc526 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -28,17 +28,6 @@ limitations under the License. namespace xla { -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. 
- return (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()); -} - bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user, @@ -149,18 +138,22 @@ bool HasUniqueFusedUseOfOperandAt( // User and operand can share buffers iff both instructions emit the same shape // and layout, and 'user' meets one of the following qualifications: -// *) Is element-wise. Or... -// *) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// *) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion +// instruction where the only use of 'operand' at 'index' in the set +// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... +// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. 
bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { + const TuplePointsToAnalysis* points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); Shape operand_subshape = @@ -170,7 +163,7 @@ bool CanShareOperandBufferWithUser( if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; } - if (user->opcode() == HloOpcode::kFusion) { + if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) { if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { @@ -180,7 +173,7 @@ bool CanShareOperandBufferWithUser( // 'operand_index', and this singleton use is the fused root at operand // index 0. return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); + *points_to_analysis); } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -208,7 +201,7 @@ bool CanShareOperandBufferWithUser( // index 'other_add_operand_index'). 
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, other_add_operand_index, - points_to_analysis); + *points_to_analysis); } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index 0b01223db73d49ad3ee127dd9076e37f5fac8ec5..c7799e5ab5d0c0d0477c09fa7e6a36c67312a72b 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -34,21 +34,16 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand, const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); -// Overload which does not require points-to analysis. The result is more -// conservative (returns false more often). -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user); - // Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). -// Returns false otherwise. +// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a +// points-to analysis argument. Without the analysis, the result is more +// conservative (returns false more often). // // REQUIRES: 'operand' is an operand of 'user'. 
bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); + const TuplePointsToAnalysis* points_to_analysis = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index bad4be149a68bdc07a1f7e4ac0668728d10d152e..6a4fde87614750d21cf9572e7f447bba924379c4 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -85,9 +85,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -122,10 +122,10 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, + points_to_analysis_.get())); + EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { @@ -143,9 +143,9 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { BuildModuleAndRunAnalysis(builder.Build()); EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - 
*points_to_analysis_)); + points_to_analysis_.get())); EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { @@ -161,10 +161,10 @@ TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, + points_to_analysis_.get())); + EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { @@ -180,9 +180,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -197,9 +197,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // The fusion instruction can share with tuple element 1. 
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { @@ -221,12 +221,12 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { // The DynamicUpdateSlice instruction can share with the data operand, but not // with update or starts. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {}, + points_to_analysis_.get())); + EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {}, + points_to_analysis_.get())); + EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { @@ -234,15 +234,15 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto dot = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -256,7 +256,7 @@ 
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { // Output fused dot add should be able to share buffer with 'add_operand'. EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { @@ -264,9 +264,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); @@ -274,7 +274,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -292,7 +292,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { // Output fused transpose-dot-add should be share buffer with 'add_operand'. 
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { @@ -300,7 +300,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -308,7 +308,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -320,7 +320,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { // Output fused operand->reverse->add cannot alias operand buffer 'operand'. EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { @@ -360,8 +360,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { RunAnalysis(); // The While instruction can share with the data operand. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {}, + points_to_analysis_.get())); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 12b2762f0ed7eb9acce8a60d4501ab6ce53c3b57..61945bd128e68b59bd0a1156882c5b29d6be2a27 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -29,7 +29,6 @@ cc_library( ":ir_array", ":llvm_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", @@ -47,7 +46,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -93,6 +91,7 @@ cc_library( deps = [ ":ir_array", ":llvm_loop", + ":ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 02710ff57f6f75fe6aa1c32670cc7196ae4c402f..1f6932bcc3fb76adb874b963ecf5fb1b16d8a9f4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "external/llvm/include/llvm/IR/MDBuilder.h" -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/types.h" @@ -51,28 +50,37 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, buffer_slice = *slices.begin(); } - llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_slice]; - if (alias_scope_md == nullptr) { - alias_scope_md = - GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain()); + if (module_.config().debug_options().xla_llvm_enable_alias_scope_metadata()) { + llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_slice]; + if (alias_scope_md == nullptr) { + alias_scope_md = + GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain()); + } + array->AddAliasScopeMetadata(alias_scope_md); } - array->AddAliasScopeMetadata(alias_scope_md); - llvm::MDNode*& noalias_md = noalias_metadata_[buffer_slice]; - if (noalias_md == nullptr) { - noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(), - assignment_, hlo); + if (module_.config().debug_options().xla_llvm_enable_noalias_metadata()) { + llvm::MDNode*& noalias_md = noalias_metadata_[buffer_slice]; + if (noalias_md == nullptr) { + noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(), + assignment_, hlo); + } + array->AddNoaliasMetadata(noalias_md); } - array->AddNoaliasMetadata(noalias_md); - // Parameters of the entry computation are never stored to, loading from a - // parameter pointer should always return the same result within a loop. 
- if (hlo.opcode() == HloOpcode::kParameter) { - const std::vector& parameter_instructions = - module_.entry_computation()->parameter_instructions(); - if (std::find(parameter_instructions.begin(), parameter_instructions.end(), - &hlo) != parameter_instructions.end()) { - array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + if (module_.config() + .debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // Parameters of the entry computation are never stored to, loading from a + // parameter pointer should always return the same result within a loop. + if (hlo.opcode() == HloOpcode::kParameter) { + const std::vector& parameter_instructions = + module_.entry_computation()->parameter_instructions(); + if (std::find(parameter_instructions.begin(), + parameter_instructions.end(), + &hlo) != parameter_instructions.end()) { + array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + } } } } @@ -87,12 +95,6 @@ llvm::MDNode* AliasAnalysis::GetAliasDomain() { llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain) { - legacy_flags::AliasAnalysisFlags* flags = - legacy_flags::GetAliasAnalysisFlags(); - if (!flags->xla_emit_alias_scope) { - return nullptr; - } - // While we could synthesize an alias.scope, doing so is not more profitable // than LLVM's default behavior. if (buffer_slice.allocation() == kParameterAllocation) { @@ -109,12 +111,6 @@ llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain, const BufferAssignment& assignment, const HloInstruction& hlo) { - legacy_flags::AliasAnalysisFlags* flags = - legacy_flags::GetAliasAnalysisFlags(); - if (!flags->xla_emit_alias_scope) { - return nullptr; - } - // We want to construct a list of buffers which: // // 1. Do not alias the given buffer. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index b259d348708c227a3e580fd352422e457284129d..26e73a6ec390c5823c2a0315480a427ea0a7b373 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -128,6 +128,27 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } +Status FusedIrEmitter::HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) { + std::vector operand_elemental_ir_types; + for (HloInstruction* operand : operands) { + operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( + operand->shape().element_type(), ir_builder_)); + } + generators_[tuple] = + [=](const IrArray::Index& index) -> StatusOr { + llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get( + ir_builder_->getContext(), operand_elemental_ir_types)); + for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index)); + ret = ir_builder_->CreateInsertValue(ret, val_i, i); + } + return ret; + }; + return Status::OK(); +} + Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 79007b7099a32973cada7a9986ff95c5e4aabec6..1cd8d1194686236dd11f71c56d668708ad113f03 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -54,6 +54,11 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; + // Emits the ir value for each element in the tuple. 
+ Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) override; + Status FinishVisit(HloInstruction* root) override; // Returns the generator function for the root of the fused computation. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e401305ae7342a9db09499c9b3846f5a0a705fa7..75b7856800d2f3e6f279d2ac2bdcf3021bbf4049 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -85,7 +85,7 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) ++depth; } - if (ShapeUtil::Rank(*shape_) == 0) { + if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); @@ -153,6 +153,28 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( return Index(source_multidim_index); } +IrArray::Index IrArray::Index::SourceIndexOfSlice( + const Shape& shape, tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice strides, + llvm::IRBuilder<>* builder) const { + Index source_index(multidim_.size()); + for (int i = 0; i < multidim_.size(); ++i) { + int64 stride = strides[i]; + auto type = multidim_[i]->getType(); + + if (stride != 1) { + source_index[i] = builder->CreateAdd( + builder->CreateMul(multidim_[i], + llvm::ConstantInt::get(type, stride)), + llvm::ConstantInt::get(type, starts[i])); + } else { + source_index[i] = builder->CreateAdd( + multidim_[i], llvm::ConstantInt::get(type, starts[i])); + } + } + return source_index; +} + IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, tensorflow::gtl::ArraySlice dimension_mapping, @@ -228,6 +250,18 @@ llvm::Value* IrArray::EmitArrayElementAddress( llvm_ir::AsStringRef(name)); } +void IrArray::AnnotateLoadStoreInstructionWithMetadata( + llvm::Instruction* instruction) 
const { + CHECK(llvm::isa(instruction) || + llvm::isa(instruction)); + + for (const auto& kind_md_pair : metadata_) { + CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load || + llvm::isa(instruction)); + instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); + } +} + llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name) const { @@ -236,9 +270,7 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::LoadInst* load = ir_builder->CreateLoad(element_address); llvm_ir::SetTbaaForInstruction(load, GetShape(), /*is_pointer_to=*/false); - for (const auto& kind_md_pair : metadata_) { - load->setMetadata(kind_md_pair.first, kind_md_pair.second); - } + AnnotateLoadStoreInstructionWithMetadata(load); return load; } @@ -248,10 +280,7 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); llvm_ir::SetTbaaForInstruction(store, GetShape(), /*is_pointer_to=*/false); - for (const auto& kind_md_pair : metadata_) { - CHECK_NE(kind_md_pair.first, llvm::LLVMContext::MD_invariant_load); - store->setMetadata(kind_md_pair.first, kind_md_pair.second); - } + AnnotateLoadStoreInstructionWithMetadata(store); } IrArray IrArray::CastToShape(const Shape& new_shape, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 91cb3a679fd67fffb29f8a935cc3c65e9442136b..5fabb1e2433248c0a2fb2a14fb6cb5dacb0dfb39 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -115,6 +115,16 @@ class IrArray { Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const; + // Returns the index into the source operand from which a slice operation + // selects a value to be placed into index "this". 
The slice is described + // by starting indices `starts` and stride values `strides`. + // + // Precondition: "this" is an index into a slice whose shape is `shape`. + Index SourceIndexOfSlice(const Shape& shape, + tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice strides, + llvm::IRBuilder<>* builder) const; + // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. Index SourceIndexOfTranspose( @@ -183,6 +193,10 @@ class IrArray { llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name = "") const; + // Attach metadata this IrArray instance knows about to "instruction". + void AnnotateLoadStoreInstructionWithMetadata( + llvm::Instruction* instruction) const; + // Emit IR to read an array element at the given index. Returns the read // result (effectively, a Value loaded from memory). This method seamlessly // handles scalar shapes by broadcasting their value to all indices (index is diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 97f1b8ce30818eaf7465933a28f30959b5e2b90a..0995ed6ff51763e7fcb281c48bd288c44a1f739f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -144,12 +144,19 @@ llvm::BasicBlock* ForLoop::CreateBasicBlockWithSuffix( std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index) { + return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1)); +} + +std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, + llvm::Value* start_index, + llvm::Value* end_index, + llvm::Value* stride) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. 
ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } - std::unique_ptr loop = ForLoop::EmitForLoop( - suffix, start_index, end_index, ir_builder_->getInt64(1), ir_builder_); + std::unique_ptr loop = + ForLoop::EmitForLoop(suffix, start_index, end_index, stride, ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock(); @@ -172,6 +179,15 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, ir_builder_->getInt64(end_index)); } +std::unique_ptr ForLoopNest::AddLoop(int64 start_index, + int64 end_index, int64 stride, + tensorflow::StringPiece suffix) { + CHECK_LE(start_index, end_index); + return AddLoop(suffix, ir_builder_->getInt64(start_index), + ir_builder_->getInt64(end_index), + ir_builder_->getInt64(stride)); +} + IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, tensorflow::StringPiece suffix) { std::vector dimensions(ShapeUtil::Rank(shape)); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 60ac0444bcde002db6fd6bfa2630c9b78157e491..a66bf80959f6579811fb8b5885d6cd209a48dc7a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -167,12 +167,22 @@ class ForLoopNest { // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have // been added then emit loop inside the body of the last added loop. + std::unique_ptr AddLoop(tensorflow::StringPiece suffix, + llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride); + + // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. 
+ std::unique_ptr AddLoop(int64 start_index, int64 end_index, + int64 stride, + tensorflow::StringPiece suffix); + + // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ff2f4cd693ca76c0e4d20522f50a302fb3ae2c40..a8c17a67f159adec94e0f16052c74e53768decc5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -22,11 +22,11 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Operator.h" #include "external/llvm/include/llvm/Target/TargetOptions.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -137,6 +137,24 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) { return result_type; } +StatusOr EncodeSelfDescribingShapeConstant( + const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder) { + string encoded_shape = shape.SerializeAsString(); + if (encoded_shape.size() > std::numeric_limits::max()) { + return InternalError("Encoded shape size exceeded int32 size limit."); + } + *shape_size = static_cast(encoded_shape.size()); + return ir_builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); +} + +StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, + int32 size_bytes) { + Shape shape; + TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes)); + 
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + return shape; +} + namespace { // Recursively construct a multidimensional LLVM constant which represents the @@ -163,36 +181,36 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, llvm::Constant* value; switch (shape.element_type()) { case PRED: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U8: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case S32: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U32: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case S64: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U64: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case F32: - value = llvm::ConstantFP::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantFP::get(ir_element_type, + literal.Get(*multi_index)); break; case F64: - value = llvm::ConstantFP::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantFP::get(ir_element_type, + literal.Get(*multi_index)); break; default: LOG(FATAL) << "unsupported type " << shape.element_type(); @@ -357,31 +375,9 @@ void EmitLogging(const char* tag, llvm::Value* value, 
void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, bool is_pointer_to) { - legacy_flags::LlvmUtilFlags* flags = legacy_flags::GetLlvmUtilFlags(); - if (!flags->xla_emit_tbaa) { - return; - } - - llvm::MDBuilder metadata_builder(instruction->getContext()); - llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA"); - string type_name; - if (is_pointer_to) { - type_name += "pointer-to "; - } - // Scalars do not have layout which makes it permissible to omit an explicit - // layout. To make sure that equivalent scalar shapes have the same TBAA, - // remove the (meaningless) explicit layout if one is present. - if (ShapeUtil::Rank(shape) == 0) { - LayoutUtil::ClearLayout(&shape); - } else { - CHECK(shape.has_layout()); - } - type_name += shape.ShortDebugString(); - llvm::MDNode* tbaa_node = - metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root); - instruction->setMetadata(llvm::LLVMContext::MD_tbaa, - metadata_builder.createTBAAStructTagNode( - tbaa_node, tbaa_node, /*Offset=*/0)); + // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code, + // most likely because the generated metadata is incorrect. Disable TBAA + // metadata while we resolve this. } void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 7b09c1f83145c2994c381686c2e6343d353becf7..d940c3fcbcfd08bd0e2a44b6721d75273c2aae5e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -106,6 +106,19 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. 
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder); +// Returns a value that represents a pointer to a global string constant that +// encodes the shape as a serialized protobuf. +StatusOr EncodeSelfDescribingShapeConstant( + const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder); + +// Inverses the encoding of a Shape protobuf into an LLVM global variable. +// +// This is intended to be called from the runtime to decode the llvm::Constants +// that are created via ConvertShapeToSelfDescribingConstant and subsequently +// embedded into the program. +StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, + int32 size_bytes); + // Converts a given literal to an IR Constant. Literals have known constant // values at IR emission time. llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 9a128b2aa6f2d5e5650624f103c573e671335f7b..8839ec582df844f46f060e26917f15aa297cba3d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -51,8 +52,41 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, shape_(target_array.GetShape()), ir_builder_(ir_builder) {} +LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + llvm::IRBuilder<>* ir_builder) + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) + -> ::tensorflow::Status { + // Convert target_element_generator to a BodyEmitter. 
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); + if (target_arrays.size() == 1) { + target_arrays[0].EmitWriteArrayElement(array_index, target_element, + ir_builder); + return tensorflow::Status::OK(); + } + + for (int64 i = 0; i < target_arrays.size(); ++i) { + target_arrays[i].EmitWriteArrayElement( + array_index, ir_builder_->CreateExtractValue(target_element, i), + ir_builder); + } + return tensorflow::Status::OK(); + }), + ir_builder_(ir_builder) { + if (target_arrays.size() > 1) { + // The sanity check for multiple outputs. + shape_ = target_arrays[0].GetShape(); + for (int64 i = 1; i < target_arrays.size(); ++i) { + const Shape& element_shape = target_arrays[i].GetShape(); + CHECK(ShapeUtil::SameDimensions(shape_, element_shape)); + } + } else { + shape_ = target_arrays[0].GetShape(); + } +} + IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock() { - CHECK(!ShapeUtil::IsTuple(shape_)); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 08171e9e9de294339359f86059f89dcf4939ddea..ab6b702c441e04f2c7988a3dcb9880a86ff95355 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -47,6 +47,10 @@ class LoopEmitter { // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder); + // Same as previous method except emits multiple targets in an array. 
+ LoopEmitter(const ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + llvm::IRBuilder<>* ir_builder); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; virtual ~LoopEmitter() = default; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 131c2ee87b0e78a4f7e315bfbb2b2793c0a91fa1..45e37c6f65efcff81cbc72737348015ce43a944f 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -46,13 +46,6 @@ namespace se = ::perftools::gputools; namespace xla { -/* static */ StatusOr> LocalService::NewService( - perftools::gputools::Platform* platform) { - ServiceOptions default_options; - default_options.set_platform(platform); - return NewService(default_options); -} - /* static */ StatusOr> LocalService::NewService( const ServiceOptions& options) { perftools::gputools::Platform* platform = options.platform(); @@ -62,7 +55,6 @@ namespace xla { BackendOptions backend_options; backend_options.set_platform(platform) - .set_number_of_replicas(options.number_of_replicas()) .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); TF_ASSIGN_OR_RETURN(std::unique_ptr backend, Backend::CreateBackend(backend_options)); @@ -70,15 +62,15 @@ namespace xla { TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr service(new LocalService( - std::move(backend), std::move(compute_constant_backend))); + options, std::move(backend), 
std::move(compute_constant_backend))); return std::move(service); } -LocalService::LocalService(std::unique_ptr execute_backend, +LocalService::LocalService(const ServiceOptions& options, + std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) - : Service(std::move(execute_backend), std::move(compute_constant_backend)) { - runs_in_client_process_ = true; -} + : Service(options, std::move(execute_backend), + std::move(compute_constant_backend)) {} namespace { // Returns the space required to allocate a shape. If @@ -152,9 +144,13 @@ StatusOr> LocalService::CompileExecutable( // Construct computation layout from the argument layouts. auto module_config = MakeUnique(*program_shape); module_config->set_has_hybrid_result(has_hybrid_result); - module_config->set_replica_count(execute_backend_->Replicas().size()); - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - if (flags->xla_hlo_profile) { + module_config->set_replica_count(options_.number_of_replicas()); + module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) { + module_config->set_intra_op_parallelism_threads( + execute_backend_->eigen_intra_op_thread_pool()->NumThreads()); + } + if (module_config->debug_options().xla_hlo_profile()) { module_config->enable_hlo_profiling(true); } auto* computation_layout = module_config->mutable_entry_computation_layout(); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 767a3ab697febb283af448b25369445152381a5e..13797ec0450bd0eb2030b111464c42e966792266 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -35,11 +35,7 @@ namespace xla { // in the same process as the client. class LocalService : public Service { public: - // Factory for creating a LocalService. 
The parameter platform is the platform - // that the service should target. If platform is null then the default - // platform is used. - static StatusOr> NewService( - perftools::gputools::Platform* platform); + // Factory for creating a LocalService. static StatusOr> NewService( const ServiceOptions& options); @@ -60,7 +56,8 @@ class LocalService : public Service { const Shape* result_layout, int device_ordinal, bool has_hybrid_result); private: - explicit LocalService(std::unique_ptr backend, + explicit LocalService(const ServiceOptions& options, + std::unique_ptr backend, std::unique_ptr compute_constant_backend); LocalService(const LocalService&) = delete; void operator=(const LocalService&) = delete; diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index d24a592f46ed2dd8fd9c927e8ed9816771a7396c..3e843b202997a09f76993acd4d02f4de9aae9854 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -27,7 +27,7 @@ limitations under the License. namespace xla { string LogicalBuffer::ToString() const { - return tensorflow::strings::StrCat(instruction_->FullyQualifiedName(), "[", + return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), "](#", id_, " @", color_.value(), ")"); } diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 566cd01ea437433e5e328ad523090e682a799233..a9f6688612002f320541b7c1d20df4dd41ea971a 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -95,11 +95,13 @@ class LogicalBuffer { // Functions which return the size and alignment of a logical buffer in bytes. 
using SizeFunction = std::function; - using AlignmentFunction = std::function; + using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id, - Color color) - : instruction_(instruction), index_(index), id_(id), color_(color) {} + LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) + : instruction_(instruction), + index_(index), + id_(id), + color_(kInvalidColor) {} Id id() const { return id_; } @@ -112,8 +114,19 @@ class LogicalBuffer { // Return the color of the logical buffer. Differently colored buffers can // not be parts of the same allocation. - Color color() const { return color_; } - void set_color(Color color) { color_ = color; } + Color color() const { + CHECK_NE(color_, kInvalidColor) + << "Should not query the color of a buffer that was never colored"; + return color_; + } + + void set_color(Color color) { + CHECK_NE(color, kInvalidColor) + << "Should not set the color of a buffer to the invalid color"; + color_ = color; + } + + bool has_color() const { return color_ != kInvalidColor; } // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. 
Therefore, the returned shape will @@ -143,6 +156,8 @@ class LogicalBuffer { static LogicalBufferProto::Location ToLocationProto( const HloInstruction& instruction, const ShapeIndex& index); + const Color kInvalidColor = Color(-1); + private: HloInstruction* instruction_; ShapeIndex index_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 4014856b9b243831a087962484128a121680eb1b..069f85af721228c8f5d40cf243eea7f1e5173c62 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -29,7 +29,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { return root; } else { tensorflow::strings::StrAppend(&root, separator_, *count); + // Increment lookup under old 'root' name. (*count)++; + // Initialize count under new 'root' name. + count = &(generated_names_[root]); + *count = 1; return root; } } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3079a0c033844666eeaa3771b467f738af7fb74 --- /dev/null +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr ReducePrecisionInsertion::Run(HloModule* module) { + bool changed = false; + VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name(); + + for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + std::vector instructions_to_suffix; + + for (auto& instruction : computation->instructions()) { + VLOG(3) << "Visited instruction: " << instruction->ToString(); + + // For now, ReducePrecision is only implemented for F32 arrays, so this + // ignore instructions that produce other data. In particular, this + // currently ignores instructions producing tuples, even if those tuples + // contain F32 arrays inside them. The assumption is that in most cases + // equivalent behavior can be obtained by adding ReducePrecision + // instructions after the instructions that pull the F32 arrays out of + // the tuples. 
+ if (instruction->shape().element_type() == PrimitiveType::F32 && + !ShapeUtil::IsScalar(instruction->shape()) && + should_reduce_output_precision_(instruction->opcode())) { + instructions_to_suffix.push_back(instruction.get()); + } + } + + for (auto& instruction : instructions_to_suffix) { + HloInstruction* reduced = + computation->AddInstruction(HloInstruction::CreateReducePrecision( + instruction->shape(), instruction, exponent_bits_, + mantissa_bits_)); + TF_RETURN_IF_ERROR( + computation->ReplaceUsesOfInstruction(instruction, reduced)); + VLOG(2) << "Inserted new op after instruction: " + << instruction->ToString(); + changed = true; + } + } + return changed; +} + +ReducePrecisionInsertion::OpcodeFilterFunction +ReducePrecisionInsertion::make_filter_function( + const HloReducePrecisionOptions& reduce_precision_options) { + // Implement the filter function with a lookup table. + std::vector filter(HloOpcodeCount(), false); + for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) { + filter[opcode] = true; + } + return [filter](const HloOpcode opcode) { + return filter[static_cast(opcode)]; + }; +} + +HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto( + const HloReducePrecisionOptions::PassTiming pass_timing, + const int exponent_bits, const int mantissa_bits, + const OpcodeFilterFunction& should_reduce_output_precision) { + HloReducePrecisionOptions options; + options.set_pass_timing(pass_timing); + options.set_exponent_bits(exponent_bits); + options.set_mantissa_bits(mantissa_bits); + for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) { + if (should_reduce_output_precision(static_cast(opcode))) { + options.add_opcodes_to_suffix(opcode); + } + } + return options; +} + +bool ReducePrecisionInsertion::AddPasses( + HloPassPipeline* pipeline, const DebugOptions& debug_options, + const HloReducePrecisionOptions::PassTiming pass_timing) { + bool passes_added = false; + for (const auto& pass_options : + 
debug_options.hlo_reduce_precision_options()) { + if (pass_options.pass_timing() == pass_timing) { + pipeline->AddPass(pass_options); + passes_added = true; + } + } + return passes_added; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h new file mode 100644 index 0000000000000000000000000000000000000000..a6fcee0039b449ad265b26fa6acfc912e3ab5731 --- /dev/null +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +// HLO pass which inserts reduce-precision instructions into the HLO graph, for +// purposes of experimenting with the effects of reduced-precision storage of +// intermediate values. +class ReducePrecisionInsertion : public HloPassInterface { + using OpcodeFilterFunction = std::function; + + public: + // The exponent_bits and mantissa_bits arguments specify the parameters of + // the instructions to insert. The instructions will be inserted after each + // instruction with an opcode for which the should_reduce_output_precision + // function returns true and the output type is F32. + explicit ReducePrecisionInsertion( + const int exponent_bits, const int mantissa_bits, + const OpcodeFilterFunction& should_reduce_output_precision) + : exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits), + should_reduce_output_precision_(should_reduce_output_precision) {} + + // Version of the constructor that takes an HloReducePrecisionOptions proto + // rather than explicitly-enumerated parameters, for convenience when + // creating passes based on DebugOptions. 
+ explicit ReducePrecisionInsertion( + const HloReducePrecisionOptions& reduce_precision_options) + : exponent_bits_(reduce_precision_options.exponent_bits()), + mantissa_bits_(reduce_precision_options.mantissa_bits()), + should_reduce_output_precision_( + make_filter_function(reduce_precision_options)) {} + + ~ReducePrecisionInsertion() override{}; + + tensorflow::StringPiece name() const override { + return "reduce-precision-insertion"; + } + + // Run the pass on the given module. Returns whether the module was changed + // (reduce-precision instructions were inserted). + StatusOr Run(HloModule* module) override; + + // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions + // representation and OpcodeFilterFunction functions. + static OpcodeFilterFunction make_filter_function( + const HloReducePrecisionOptions& reduce_precision_options); + static HloReducePrecisionOptions make_options_proto( + const HloReducePrecisionOptions::PassTiming pass_timing, + const int exponent_bits, const int mantissa_bits, + const OpcodeFilterFunction& should_reduce_output_precision); + + // Add ReducePrecisionInsertion passes to an HloPassPipeline based on the list + // of HloReducePrecisionOptions in a DebugOptions proto. Returns true if any + // passes were added. + static bool AddPasses( + HloPassPipeline* pipeline, const DebugOptions& debug_options, + const HloReducePrecisionOptions::PassTiming pass_timing); + + private: + // Parameters for the precision reduction to be added. + const int exponent_bits_; + const int mantissa_bits_; + + // Function to determine (from the opcode) whether a given instruction should + // have a reduce-precision instruction inserted in its output stream. 
+ const OpcodeFilterFunction should_reduce_output_precision_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..80717ec2e3f43a968b04dae1367cb7f78fa08b25 --- /dev/null +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { + +using ::testing::UnorderedElementsAre; + +class ReducePrecisionInsertionTest : public HloTestBase { + protected: + bool InsertOps(HloModule* module, + const std::function& filter) { + ReducePrecisionInsertion op_insertion(5, 10, filter); + StatusOr result = op_insertion.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(ReducePrecisionInsertionTest, RootInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a simple graph with a parameter feeding a unary cosine function. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + + EXPECT_TRUE(InsertOps(module.get(), + [](HloOpcode h) { return h == HloOpcode::kCos; })); + + // Confirm expected graph after adding ops. 
+ EXPECT_THAT(computation->root_instruction(), op::ReducePrecision()); + EXPECT_EQ(computation->root_instruction()->operand(0), b); +} + +TEST_F(ReducePrecisionInsertionTest, NonRootInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a graph with two parameters feeding into unary cosine functions, + // and the output of those feeds into an add function. Feeding the outputs + // from the suffixed cosine functions into a binary add function allows us to + // confirm that the separate operand streams are not crossed when the new + // instructions are inserted. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* a_cos = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* b_cos = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, b)); + + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Confirm expected graph before adding ops. + EXPECT_EQ(c->operand(0), a_cos); + EXPECT_EQ(c->operand(1), b_cos); + + EXPECT_TRUE(InsertOps(module.get(), + [](HloOpcode h) { return h == HloOpcode::kCos; })); + + // Confirm expected graph after adding ops. 
+ EXPECT_THAT(c->operand(0), op::ReducePrecision()); + EXPECT_EQ(c->operand(0)->operand(0), a_cos); + EXPECT_THAT(c->operand(1), op::ReducePrecision()); + EXPECT_EQ(c->operand(1)->operand(0), b_cos); +} + +TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(S32, {4}); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected graph before adding ops. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); + + // Since none of the instructions produce F32 data, this should not change + // the graph. + EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode) { return true; })); + + // Confirm that graph has not changed. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); +} + +TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected graph before adding ops. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); + + // Since none of the instructions match the should_reduce_output_precision + // function, this should not change the graph. 
+ EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode h) { return false; })); + + // Confirm that graph has not changed. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); +} + +TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateReducePrecision(shape, a, 9, 23)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + + // This should insert a new ReducePrecision after the existing one, but + // should not then recurse by adding another after the just-inserted one. + EXPECT_TRUE(InsertOps(module.get(), [](HloOpcode h) { + return h == HloOpcode::kReducePrecision; + })); + + // Confirm expected graph after adding ops. 
+ EXPECT_THAT(computation->root_instruction(), op::ReducePrecision()); + EXPECT_EQ(computation->root_instruction()->operand(0), b); +} + +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 2d35ba5e5480511d93a85d8e54ad8983551a329c..1c648d58c7fca25f2cc9069b12532007083cc02d 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -312,10 +312,17 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, StatusOr ReshapeMover::Run(HloModule* module) { bool changed = false; - for (const auto& comp : module->computations()) { + std::vector computations; + for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (const auto& comp : computations) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool did_change, - TrySinkReshapeOrTranspose(comp.get(), instruction)); + TrySinkReshapeOrTranspose(comp, instruction)); changed |= did_change; } } diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 9becdb2bed480d610e658303ee7deff4cf7d2743..1589d52a256df1914201c866859008c0f1df8a8f 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -84,7 +84,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateFromShape(root_shape))); + HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape))); builder.AddInstruction(HloInstruction::CreateBinary( 
root_shape, HloOpcode::kAdd, reshape0, const1)); @@ -179,9 +179,8 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{true, true, false}, {false, false, true}}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{true, true, false}, {false, false, true}}))); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); @@ -263,12 +262,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); @@ -318,7 +317,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -352,16 +351,15 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = 
builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - auto fusion = computation->AddInstruction(HloInstruction::CreateFusion( - add->shape(), HloInstruction::FusionKind::kLoop, add)); - TF_CHECK_OK(computation->ReplaceInstruction(add, fusion)); + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add}, + HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, param1))); @@ -464,7 +462,7 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {128, 1}), param0)); Array2D a(128, 1024); - auto literal = LiteralUtil::CreateR2FromArray2D(a); + auto literal = Literal::CreateR2FromArray2D(a); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 85ca7e4e59ce9e69a671b829f3c2c3a4834a99ce..25e3f57dfb1c994bd6c96ed6ce18190a0088e963 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -141,12 +141,13 @@ int ServiceOptions::intra_op_parallelism_threads() const { } BackendOptions backend_options; backend_options.set_platform(platform); - backend_options.set_number_of_replicas(options.number_of_replicas()); TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); - std::unique_ptr service(new Service( - std::move(execute_backend), std::move(compute_constant_backend))); + std::unique_ptr service( + new Service(options, std::move(execute_backend), + std::move(compute_constant_backend))); return std::move(service); } @@ -158,24 +159,25 @@ Service::CreateComputeConstantBackend() { if (platform->id() == se::host::kHostPlatformId) { BackendOptions backend_options; backend_options.set_platform(platform); - backend_options.set_number_of_replicas(1); return Backend::CreateBackend(backend_options); } } return NotFound("CPU platform not found"); } -/* static */ Compiler::HloDumper Service::MakeHloDumper() { - return [](const HloModule& module, const string& label) { - return Executable::DumpExecutedHlo(module, label, /*profile=*/nullptr); - }; -} - -Service::Service(std::unique_ptr execute_backend, +Service::Service(const ServiceOptions& options, + std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) - : execute_backend_(std::move(execute_backend)), + : options_(options), + execute_backend_(std::move(execute_backend)), compute_constant_backend_(std::move(compute_constant_backend)) { + CHECK(options_.number_of_replicas() > 0); + if (execute_backend_) { + if 
(execute_backend_->device_count() > 0) { + CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) + << "Requested more replicas than there are devices."; + } LOG(INFO) << Printf( "XLA service %p executing computations on platform %s. Devices:", this, execute_backend_->platform()->Name().c_str()); @@ -285,7 +287,7 @@ StatusOr> Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, Backend* backend) { + const ExecutionOptions& execution_options) { auto module_config = MakeUnique(program_shape); auto* computation_layout = module_config->mutable_entry_computation_layout(); @@ -320,12 +322,11 @@ StatusOr> Service::CreateModuleConfig( shape_with_output_layout)); } - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - if (flags->xla_hlo_profile) { + if (execution_options.debug_options().xla_hlo_profile()) { module_config->enable_hlo_profiling(true); } - module_config->set_replica_count(backend->Replicas().size()); + module_config->set_replica_count(options_.number_of_replicas()); module_config->set_seed(execution_options.seed()); module_config->set_debug_options(execution_options.debug_options()); @@ -341,23 +342,25 @@ StatusOr>> Service::BuildExecutables( // Dump computation proto state if flag is set. 
std::vector> session_modules; - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - const string& directory_path = flags->xla_dump_computations_to; - const string& other_directory_path = flags->xla_dump_executions_to; - if ((!directory_path.empty() || !other_directory_path.empty())) { - for (int64 i = 0; i < versioned_handles.size(); ++i) { - TF_ASSIGN_OR_RETURN(std::unique_ptr session_module, - computation_tracker_.SnapshotComputation( - versioned_handles[i].handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); - } + for (int64 i = 0; i < versioned_handles.size(); ++i) { + const string& directory_path = + module_configs[i]->debug_options().xla_dump_computations_to(); + const string& other_directory_path = + module_configs[i]->debug_options().xla_dump_executions_to(); + if (directory_path.empty() && other_directory_path.empty()) { + continue; + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr session_module, + computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); + if (!directory_path.empty()) { + string filename = Printf("computation_%lld__%s__version_%lld", + versioned_handles[i].handle.handle(), + session_module->entry().name().c_str(), + versioned_handles[i].version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + session_modules.push_back(std::move(session_module)); } } @@ -378,14 +381,12 @@ StatusOr>> Service::BuildExecutables( modules.push_back(std::move(module)); } - Compiler::HloDumper hlo_dumper = MakeHloDumper(); TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), hlo_dumper, - std::move(executors))); + 
backend->compiler()->Compile(std::move(modules), std::move(executors))); - if (!other_directory_path.empty()) { - for (size_t i = 0; i < versioned_handles.size(); ++i) { + for (size_t i = 0; i < versioned_handles.size(); ++i) { + if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { executables[i]->set_session_module(std::move(session_modules[i])); } } @@ -405,9 +406,10 @@ StatusOr> Service::BuildExecutable( // Dump computation proto state if flag is set. std::unique_ptr session_module; - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - const string& directory_path = flags->xla_dump_computations_to; - const string& other_directory_path = flags->xla_dump_executions_to; + const string& directory_path = + module_config->debug_options().xla_dump_computations_to(); + const string& other_directory_path = + module_config->debug_options().xla_dump_executions_to(); if (!executable_for_compute_constant && (!directory_path.empty() || !other_directory_path.empty())) { TF_ASSIGN_OR_RETURN( @@ -429,15 +431,9 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ !executable_for_compute_constant)); - Compiler::HloDumper hlo_dumper = MakeHloDumper(); - if (executable_for_compute_constant && - !flags->xla_hlo_graph_for_compute_constant) { - hlo_dumper = [](const HloModule&, const string&) {}; - } - TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend->compiler()->Compile(std::move(module), hlo_dumper, executor)); + backend->compiler()->Compile(std::move(module), executor)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -495,47 +491,55 @@ Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice< std::vector> arguments, - Backend* backend, - tensorflow::gtl::ArraySlice executors, + Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags) { - // TODO(b/33943292): Support for replication when using 
multiple computations. - TF_RET_CHECK(backend->Replicas().size() == 1); - - // Set up streams. + // Streams where the computation are launched, so we can wait on the streams + // to complete. std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : executors) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - backend->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - // Set up run options. - std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); - options.set_intra_op_thread_pool( - backend->eigen_intra_op_thread_pool_device()); - run_options.emplace_back(options, backend->StreamBorrower()); - } - - // Asynchronously launch all executables. + // Global data handles for the computation results, one for each computation. std::vector result_handles; - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < executables.size(); i++) { - TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, - executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i])); - result_handles.push_back(allocation_tracker_.Register( - backend, executors[i]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), executables.size())); + + for (int64 i = 0; i < executables.size(); i++) { + // Stream executors for the replicas of the current computation. + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + backend->BorrowStream(replicas[replica])); + streams.push_back(std::move(stream)); + + // Set up run options. 
+ ExecutableRunOptions options; + options.set_stream(streams.back().get()); + options.set_allocator(backend->memory_allocator()); + options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); + options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + ServiceExecutableRunOptions run_options(options, + backend->StreamBorrower()); + + // Asynchronously launch the computation. + TF_ASSIGN_OR_RETURN( + perftools::gputools::DeviceMemoryBase result, + executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + + // All replicas share the same device address for the result allocation, + // so only one of the replicas need to register the result handle. + if (replica == 0) { + result_handles.push_back(allocation_tracker_.Register( + backend, replicas[0]->device_ordinal(), result, + executables[i]->result_shape(), result_tags[i])); + } + } } // Wait for all executions to complete. - for (int64 i = 0; i < result_handles.size(); ++i) { + for (int64 i = 0; i < streams.size(); ++i) { if (!streams[i]->BlockHostUntilDone()) { return InternalError("failed to complete execution for stream %lld", i); } @@ -550,17 +554,23 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { - TF_RET_CHECK(!backend->Replicas().empty()); - // Set up streams. 
std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : backend->Replicas()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*backend, SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); + for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), + /*computation_count=*/1)); + // Set up run options. std::vector run_options; for (const Pool::SmartPtr& stream : streams) { @@ -570,19 +580,20 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); run_options.emplace_back(options, backend->StreamBorrower(), backend->inter_op_thread_pool()); } perftools::gputools::DeviceMemoryBase result; - if (backend->Replicas().size() == 1) { + if (options_.number_of_replicas() == 1) { TF_ASSIGN_OR_RETURN( result, executable->ExecuteOnStreamWrapper( &run_options[0], profile, arguments)); } else { std::vector< tensorflow::gtl::ArraySlice> - repeated_arguments(backend->Replicas().size(), arguments); + repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); @@ -610,25 +621,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; + std::vector device_handles; - if (arg->requests_size() > execute_backend_->stream_executors().size()) { + if (arg->requests_size() * options_.number_of_replicas() > + execute_backend_->device_count()) { return FailedPrecondition( "there are not enough stream executors to 
execute %d computations", arg->requests_size()); } for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor on which the computation will run. Select the - // specific device if requested, otherwise select the i'th device from the - // list of available stream executors. - se::StreamExecutor* executor; - if (arg->requests(i).has_device_handle()) { - executor = - execute_backend_ - ->stream_executors()[arg->requests(i).device_handle().handle()]; - } else { - executor = execute_backend_->stream_executors()[i]; + // Get the stream executor for the i'th computation. This stream executor + // is one of the executors to run the replicated computation. + if (!arg->requests(i).has_device_handle()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); } + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, arg->requests(i).device_handle())); + se::StreamExecutor* executor = replicas[0]; CHECK(executor != nullptr); // Resolve the UserComputation object associated with the requested @@ -662,8 +674,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // the program and the argument allocations. 
TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, CreateModuleConfig(*program_shape, arg_allocations, - request.execution_options(), - execute_backend_.get())); + request.execution_options())); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -673,6 +684,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, module_configs.push_back(std::move(module_config)); computation_names.push_back(user_computation->name()); executors.push_back(executor); + device_handles.push_back(arg->requests(i).device_handle()); } // Build the user computations into HloModules and compile to generate the @@ -692,7 +704,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), executors, + execute_backend_.get(), device_handles, computation_names)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; @@ -706,10 +718,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) { - const int64 available_device_count = - execute_backend_->stream_executors().size(); - const int64 replicas = execute_backend_->Replicas().size(); - if (available_device_count < arg->device_count() * replicas) { + const int64 available_device_count = execute_backend_->device_count(); + const int64 replica_count = options_.number_of_replicas(); + if (replica_count <= 0) { + return FailedPrecondition("Replica count must be a positive integer"); + } + if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( "Requested device count (%lld) exceeds the number of available devices " "on the target (%lld)", @@ -718,8 +732,8 @@ tensorflow::Status 
Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, for (int64 i = 0; i < arg->device_count(); ++i) { DeviceHandle device_handle; - device_handle.set_handle( - execute_backend_->stream_executors()[i * replicas]->device_ordinal()); + device_handle.set_handle(i); + device_handle.set_device_count(arg->device_count()); *result->add_device_handles() = device_handle; } @@ -749,10 +763,9 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options(), execute_backend_.get())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + arg->execution_options())); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -818,10 +831,9 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options(), execute_backend_.get())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + arg->execution_options())); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -841,11 +853,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, execute_backend_->default_stream_executor(), &profile)); - TF_RET_CHECK(!execute_backend_->Replicas().empty()); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); + // Set up 
streams. std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : execute_backend_->Replicas()) { + for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, execute_backend_->BorrowStream(executor)); streams.push_back(std::move(stream)); @@ -927,19 +942,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); - if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { + if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { // TODO(b/32990684): Tuple transfers to host end up allocating further // buffers - implement that correctly. return Unimplemented( "Tuple transfers to the device not supported with replication."); } - se::StreamExecutor* stream_executor; + std::vector replicas; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(replicas, + Replicas(*execute_backend_, arg->device_handle())); } else { - stream_executor = execute_backend_->default_stream_executor(); + TF_ASSIGN_OR_RETURN( + replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } // Allocate memory on the device, using the stream executor. 
The size of the @@ -950,14 +966,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, execute_backend_->memory_allocator()->Allocate( - stream_executor->device_ordinal(), allocation_size)); + replicas[0]->device_ordinal(), allocation_size)); *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), stream_executor->device_ordinal(), allocation, - shape, StrCat("TransferToServer literal of size ", allocation_size)); + execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, + StrCat("TransferToServer literal of size ", allocation_size)); - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - stream_executor->device_ordinal())); for (se::StreamExecutor* executor : replicas) { TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( @@ -968,7 +982,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) { - const int64 replica_count = execute_backend_->Replicas().size(); + const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( "%s", @@ -980,11 +994,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, arg->device_handle())); executor = replicas[arg->replica_id()]; } else { - executor = execute_backend_->Replicas()[arg->replica_id()]; + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); + executor = replicas[arg->replica_id()]; } return 
execute_backend_->transfer_manager()->TransferLiteralToInfeed( @@ -994,7 +1011,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, tensorflow::Status Service::TransferFromOutfeed( const TransferFromOutfeedRequest* arg, TransferFromOutfeedResponse* result) { - const int64 replica_count = execute_backend_->Replicas().size(); + const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " @@ -1004,11 +1021,14 @@ tensorflow::Status Service::TransferFromOutfeed( se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, arg->device_handle())); executor = replicas[arg->replica_id()]; } else { - executor = execute_backend_->Replicas()[arg->replica_id()]; + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); + executor = replicas[arg->replica_id()]; } Literal literal; @@ -1085,8 +1105,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - compute_constant_backend_.get())); + CreateModuleConfig(program_shape, {}, execution_options)); TF_ASSIGN_OR_RETURN( std::shared_ptr executable, @@ -1146,11 +1165,14 @@ tensorflow::Status Service::GetComputationStats( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); + HloModuleConfig config; + config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN( std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig())); + computation_tracker_.BuildHloModule(versioned_handle, config)); - MakeHloDumper()(*module, 
"computation statistics subject"); + hlo_graph_dumper::MaybeDumpHloModule(*module, + "computation statistics subject"); // Run HLO analysis to get the computation statistics. HloCostAnalysis analysis( @@ -1166,17 +1188,6 @@ tensorflow::Status Service::GetComputationStats( return tensorflow::Status::OK(); } -tensorflow::Status Service::CheckRunsInClientProcess( - const string& method_name) const { - if (runs_in_client_process_) { - return tensorflow::Status::OK(); - } else { - return FailedPrecondition( - "%s only supported if service runs in the same process as the client", - method_name.c_str()); - } -} - template tensorflow::Status Service::AddInstruction( const RequestT* arg, ResponseT* result, @@ -1195,6 +1206,14 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { StatusOr handle_status; switch (arg->op_case()) { + case OpRequest::kBatchNormTrainingRequest: + handle_status = computation->AddBatchNormTrainingInstruction( + arg->batch_norm_training_request()); + break; + case OpRequest::kBatchNormGradRequest: + handle_status = computation->AddBatchNormGradInstruction( + arg->batch_norm_grad_request()); + break; case OpRequest::kBinaryOpRequest: handle_status = computation->AddBinaryInstruction(arg->binary_op_request()); @@ -1277,6 +1296,11 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { computation->AddReduceInstruction(arg->reduce_request(), *to_apply); break; } + case OpRequest::kReducePrecisionRequest: { + handle_status = computation->AddReducePrecisionInstruction( + arg->reduce_precision_request()); + break; + } case OpRequest::kReduceWindowRequest: { TF_ASSIGN_OR_RETURN(UserComputation * to_apply, computation_tracker_.Resolve( @@ -1383,4 +1407,28 @@ tensorflow::Status Service::LoadComputationSnapshot( return tensorflow::Status::OK(); } +DeviceHandle Service::SingleComputationDeviceHandle() const { + DeviceHandle device_handle; + device_handle.set_handle(0); + device_handle.set_device_count(1); + return 
device_handle; +} + +StatusOr> Service::Replicas( + const Backend& backend, const DeviceHandle& device_handle) const { + std::vector replicas; + for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { + // From the computation placer, find out the device ids of the replicas for + // the given device handle. + TF_ASSIGN_OR_RETURN( + int device_ordinal, + backend.computation_placer()->DeviceId(replica, device_handle.handle(), + options_.number_of_replicas(), + device_handle.device_count())); + TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal)); + replicas.push_back(executor); + } + return replicas; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index abd1281bdd0ab76297bc64493ec77bbc35fb552b..ccd699516e1b874546340e1650a31067aecb6886 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -22,12 +22,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" #include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -58,8 +57,7 @@ class ServiceOptions { perftools::gputools::Platform* platform() const; // Set the number of replicas to use when compiling replicated - // programs. The default is -1 meaning that the value is read from - // the xla_replicas flag. + // programs. 
ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; @@ -69,7 +67,7 @@ class ServiceOptions { private: perftools::gputools::Platform* platform_ = nullptr; - int number_of_replicas_ = -1; + int number_of_replicas_ = 1; int intra_op_parallelism_threads_ = -1; }; @@ -126,7 +124,7 @@ class Service : public ServiceInterface { // least N * R devices must be available. The devices are assigned based on // the device ordinals such that the first R available devices are assigned to // the first set of replicas, and the next R devices to the second set of - // replicas, etc. Each returned device handles represent the device with the + // replicas, etc. Each returned device handle represents the device with the // replica id 0. tensorflow::Status GetDeviceHandles( const GetDeviceHandlesRequest* arg, @@ -248,7 +246,7 @@ class Service : public ServiceInterface { // The constructor is private. Use the NewService factory to create new // service objects. - Service(std::unique_ptr backend, + Service(const ServiceOptions& options, std::unique_ptr backend, std::unique_ptr compute_constant_backend); static StatusOr> CreateComputeConstantBackend(); @@ -264,7 +262,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, Backend* backend); + const ExecutionOptions& execution_options); // Builds an Executable for the given parameters. If // executable_for_compute_constant is true, then the executable is intended to @@ -319,14 +317,9 @@ class Service : public ServiceInterface { std::vector> arguments, Backend* backend, - tensorflow::gtl::ArraySlice - executors, + tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags); - // Returns an HLO dumper for use in the compiler (it refers to flags - // associated with the service). 
- static Compiler::HloDumper MakeHloDumper(); - // Convenience function for adding a function to a user computation. template tensorflow::Status AddInstruction( @@ -334,18 +327,24 @@ class Service : public ServiceInterface { const std::function(UserComputation*)>& adder); - // If the service is running in the client process - // (runs_in_client_process_ is true) then return - // tensorflow::Status::OK. Otherwise return an appropriate error - // status with the given method name. Used for "InProcess" methods. - tensorflow::Status CheckRunsInClientProcess(const string& method_name) const; - // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. tensorflow::Status ValidateResultShapeWithLayout( const Shape& shape_with_layout, const Shape& result_shape) const; + // Returns the stream executors assigned to the replicas represented by the + // given device handle. Each device_handle is a virtual replicated device that + // represents a set of physical devices for the replicas. + StatusOr> Replicas( + const Backend& backend, const DeviceHandle& device_handle) const; + + // Returns the device handle that represents the replicated device for a + // single computation that is not model-parallelized. + DeviceHandle SingleComputationDeviceHandle() const; + + ServiceOptions options_; + // Tracks computations built via the API. ComputationTracker computation_tracker_; @@ -369,9 +368,6 @@ class Service : public ServiceInterface { // Backend to use when executing ComputeConstant. std::unique_ptr compute_constant_backend_; - // Whether the service runs in the same process as the client. 
- bool runs_in_client_process_ = false; - TF_DISALLOW_COPY_AND_ASSIGN(Service); }; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d6436cf988db7632ecf89f1a1e274a0fbab00ce2..40206145c8987083e0b00ceb48ad7a6e7c6cd926 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -184,6 +184,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: + case UNOP_COS: + case UNOP_SIN: case UNOP_EXP: case UNOP_LOG: case UNOP_TANH: @@ -297,6 +299,30 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } +/* static */ StatusOr ShapeInference::InferReducePrecisionShape( + const Shape& operand_shape, const int exponent_bits, + const int mantissa_bits) { + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "expected element type in shape to be floating point for " + "ReducePrecision operation; got %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + if (exponent_bits < 1) { + // One exponent bit is necessary to distinguish 0 from infinity. Having + // no exponent bits doesn't produce a sensible number, so we require at + // least one. + return InvalidArgument("expected exponent_bits >= 1; got %d", + exponent_bits); + } + if (mantissa_bits < 0) { + // A number with no mantissa bits is still meaningful, however. 
+ return InvalidArgument("expected non-negative mantissa_bits; got %d", + mantissa_bits); + } + return operand_shape; +} + /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { @@ -525,9 +551,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementType(lhs, rhs)) { - return InvalidArgument("binary op with different element types: %s and %s", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + return InvalidArgument( + "binary op %s with different element types: %s and %s", + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && @@ -754,6 +782,263 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( AsInt64Slice(arg_shape->dimensions())); } +/* static */ StatusOr ShapeInference::InferBatchNormTrainingShape( + const Shape& operand_shape, const Shape& offset_shape, + const Shape& scale_shape, int64 feature_index) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + scale_shape, "scale input of batch norm training")); + + TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) == + tensorflow::Status::OK()); + + if (feature_index >= ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Expected feature_index of batch-norm-training to be " + "smaller than the rank of operand_shape; " + "got feature_index %lld, and rank %lld", + feature_index, 
ShapeUtil::Rank(operand_shape)); + } + + if (feature_index < 0) { + return InvalidArgument( + "Expected feature_index of batch-norm-training to " + "be a non-negative number, got %lld", + feature_index); + } + + if (ShapeUtil::Rank(operand_shape) < 1) { + return InvalidArgument( + "Expected the rank of operand to " + "batch-norm-training to be at least 1; got %lld", + ShapeUtil::Rank(operand_shape)); + } + + if (ShapeUtil::Rank(offset_shape) != 1) { + return InvalidArgument( + "Offset input of batch-norm-training must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(offset_shape)); + } + + if (ShapeUtil::Rank(scale_shape) != 1) { + return InvalidArgument( + "Scale input of batch-norm-training must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(scale_shape)); + } + + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "The operand to batch-norm-training must have a floating point " + "element type, but the shape is %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-training, " + "but the shape of offset factor is %s " + "and the shape of operand is %s", + PrimitiveType_Name(offset_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-training, " + "but the shape of scale factor is %s " + "and the shape of operand is %s", + PrimitiveType_Name(scale_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + const int64 feature_count = operand_shape.dimensions(feature_index); + Shape output_shape_for_mean_and_var = + ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}); + + if 
(ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { + return InvalidArgument( + "The size of offset factor should be the same as feature count," + "but the size of offset factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(offset_shape, 0), feature_count); + } + + if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + return InvalidArgument( + "The size of scale factor should be the same as feature count," + "but the size of scale factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(scale_shape, 0), feature_count); + } + + return ShapeUtil::MakeTupleShape({operand_shape, + output_shape_for_mean_and_var, + output_shape_for_mean_and_var}); +} + +/* static */ StatusOr ShapeInference::InferBatchNormGradShape( + const Shape& operand_shape, const Shape& scale_shape, + const Shape& mean_shape, const Shape& var_shape, + const Shape& output_grad_shape, int64 feature_index) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad")); + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + output_grad_shape, "output_grad input of batch norm grad")); + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(operand_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(mean_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(scale_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(var_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(output_grad_shape)); + + if (feature_index >= ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Expected feature_index of batch-norm-grad to be " + "smaller than the rank of operand_shape; " + "got feature_index %lld, and rank %lld", + feature_index, 
ShapeUtil::Rank(operand_shape)); + } + + if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { + return InvalidArgument( + "Expected operand_shape of batch-norm-grad to have the same rank as" + " output_grad_shape; got rank(oprand_shape) %lld, and" + " rank(output_grad_shape) %lld", + ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); + } + + if (ShapeUtil::Rank(mean_shape) != 1) { + return InvalidArgument( + "Mean input of batch-norm-grad must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(mean_shape)); + } + + if (ShapeUtil::Rank(scale_shape) != 1) { + return InvalidArgument( + "Scale input of batch-norm-grad must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(scale_shape)); + } + + if (ShapeUtil::Rank(var_shape) != 1) { + return InvalidArgument( + "Var input of batch-norm-grad must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(var_shape)); + } + + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "The operand to batch-norm-grad must have a floating point " + "element type, but the shape is %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { + return InvalidArgument( + "The output_grad to batch-norm-grad must have a floating point " + "element type, but the shape is %s", + PrimitiveType_Name(output_grad_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-grad, " + "but the element type of output_grad is %s " + "and the element type of operand is %s", + PrimitiveType_Name(output_grad_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-grad, " + "but 
the element type of scale factor is %s " + "and the element type of operand is %s", + PrimitiveType_Name(scale_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-grad, " + "but the element type of mean is %s " + "and the element type of operand is %s", + PrimitiveType_Name(mean_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(var_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-grad, " + "but the element type of mean is %s " + "and the element type of operand is %s", + PrimitiveType_Name(mean_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + const int64 feature_count = operand_shape.dimensions(feature_index); + + Shape feature_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}); + + if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { + return InvalidArgument( + "The size of mean should be the same as feature count," + "but the size of offset factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(mean_shape, 0), feature_count); + } + + if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + return InvalidArgument( + "The size of scale factor should be the same as feature count," + "but the size of scale factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(scale_shape, 0), feature_count); + } + + if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { + return InvalidArgument( + "The size of variance should be the same as feature count," + "but the size of variance is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(var_shape, 0), feature_count); + } + + // 
Verify operand_shape and output_grad_shape have same bounds. + for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + if (ShapeUtil::GetDimension(operand_shape, i) != + ShapeUtil::GetDimension(output_grad_shape, i)) { + return InvalidArgument( + "The bounds of operand shape should be the same as output_grad's," + "but the bound of operand_shape at dimension %lld is %lld " + "and the bound of output_grad_shape is %lld", + i, ShapeUtil::GetDimension(operand_shape, i), + ShapeUtil::GetDimension(output_grad_shape, i)); + } + } + + return ShapeUtil::MakeTupleShape( + {operand_shape, feature_shape, feature_shape}); +} + /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { @@ -1019,6 +1304,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( starts.size(), limits.size()); } + if (starts.size() != strides.size()) { + return InvalidArgument("slice start and strides sizes differ: %zu vs %zu", + starts.size(), strides.size()); + } + if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( "slice index count does not match argument rank: %zu vs %lld", @@ -1034,9 +1324,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument("negative start index to slice: %lld", start_index); } - if (stride == 0) { - return InvalidArgument("Zero stride"); - } if (limit_index > arg.dimensions(dimension)) { return InvalidArgument( "limit index (%lld) must be less than or equal to dimension " @@ -1047,17 +1334,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); - if (stride > 0) { - if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index); - } - sizes.push_back((limit_index - start_index + 
stride - 1) / stride); - } else { - return InvalidArgument("Negative strides not supported"); + if (start_index > limit_index) { + return InvalidArgument( + "limit index (%lld) must be greater or equal to " + "start index (%lld) in slice with positive stride", + limit_index, start_index); } + if (stride <= 0) { + return InvalidArgument("stride (%lld) must be positive", stride); + } + sizes.push_back((limit_index - start_index + stride - 1) / stride); } return ShapeUtil::MakeShape(arg.element_type(), sizes); @@ -1394,10 +1680,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { + string computation_signature = ShapeUtil::HumanString(to_apply); + string argument_shapes = tensorflow::str_util::Join( + arg_shapes, ", ", [](string* out, const Shape* shape) { + tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); + }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu", - to_apply.parameters_size(), arg_shapes.size()); + "arity: %d, arguments: %zu; computation signature: %s; argument " + "shapes: [%s]", + to_apply.parameters_size(), arg_shapes.size(), + computation_signature.c_str(), argument_shapes.c_str()); } // All arguments must be compatible with the program shape. diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0d270f99794bd7a17a1df555b9b666a50d4b7e17..f3f0176a434e350cd2be9d3b8c1fe0aa72972433 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -64,6 +64,21 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); + // Infers the shape produced by InferBatchNormTraining with the given + // operands. 
+ static StatusOr InferBatchNormTrainingShape(const Shape& operand_shape, + const Shape& offset_shape, + const Shape& scale_shape, + int64 feature_index); + + // Infers the shape produced by InferBatchNormGrad with the given operands. + static StatusOr InferBatchNormGradShape(const Shape& operand_shape, + const Shape& scale_shape, + const Shape& mean_shape, + const Shape& var_shape, + const Shape& output_grad_shape, + int64 feature_index); + // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( @@ -165,6 +180,12 @@ class ShapeInference { static StatusOr InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, + // and returns the result shape. + static StatusOr InferReducePrecisionShape(const Shape& operand_shape, + const int exponent_bits, + const int mantissa_bits); + // Helper that infers the shape produced by a pad operation based on the // padding configuration. static StatusOr InferPadShape(const Shape& operand_shape, diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 15f6b7bfb4a7f507272471c406bd2ade3ab27b20..c79ffa9cd73950b1653f72b1c6286346f76c10fb 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -65,6 +65,17 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. + // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. 
+ virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfers the given literal from the Outfeed interface of the device, // using the given executor. virtual Status TransferLiteralFromOutfeed( diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index ca38601d919adfdfd637dab44796ffa4969cc8f2..29ecef9510cfe6b8764c2e5fe1216255ca1dc983 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -55,7 +55,7 @@ class CpuTransferManagerTest : public ::testing::Test { TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { std::vector storage(sizeof(uint32), '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = LiteralUtil::CreateR0(42); + std::unique_ptr literal = Literal::CreateR0(42); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); @@ -66,7 +66,7 @@ TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { std::vector storage(4 * sizeof(float), '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); std::unique_ptr literal = - LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); @@ -80,7 +80,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { std::vector storage(16, '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); const char* str = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr literal = Literal::CreateR1U8(str); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc 
index a0c88c6bbc23972bb6a0f3729e51ee0eaee72bc7..585833573606058514d20fa396b433497ec65bd6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -172,7 +172,14 @@ StatusOr TransposeFolding::Run(HloModule* module) { return tensorflow::Status::OK(); }; - for (auto& comp : module->computations()) { + std::vector computations; + for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (auto& comp : computations) { TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index c72d127ea86e4e9daf99dff4335c538c081f0605..9520c42d280968e3f21a110089583c94277ef1a6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -92,11 +92,11 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { auto builder = HloComputation::Builder("entry_computation"); // 2x1 HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2({{1}, {2}}))); + HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); // 3x2 HloInstruction* const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); HloInstruction* transpose0 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); @@ -130,11 +130,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* 
const2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* const3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( const1->shape(), HloOpcode::kAdd, const1, const2)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index ad6f015c70e7241af815246b732fa02768cf0a10..3c4dc19aefa9cb80a25abd916f417e0535ab5171 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -33,9 +33,9 @@ limitations under the License. namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat( - "BufferAlias(", instruction_->FullyQualifiedName(), "[", - tensorflow::str_util::Join(index_, ","), "])"); + return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", + tensorflow::str_util::Join(index_, ","), + "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -125,21 +125,19 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } /* static */ StatusOr> -TuplePointsToAnalysis::Run(const HloModule* module, Colorer colorer) { +TuplePointsToAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( - new TuplePointsToAnalysis(module, std::move(colorer))); + new TuplePointsToAnalysis(module)); TF_RETURN_IF_ERROR(analysis->Analyze()); return std::move(analysis); } -/* static */ StatusOr> -TuplePointsToAnalysis::Run(const HloModule* module) { - return Run(module, DefaultColorer()); -} - Status TuplePointsToAnalysis::Analyze() { points_to_.clear(); for (auto& computation : 
module_->computations()) { + if (computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); @@ -171,9 +169,6 @@ Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( const ShapeIndex& index, const std::vector& pointed_to_buffers) { for (const LogicalBuffer* buffer : pointed_to_buffers) { - if (buffer_aliases_.count(buffer) == 0) { - buffer_aliases_.insert({buffer, std::vector()}); - } buffer_aliases_[buffer].emplace_back(instruction.get(), index); } }); @@ -184,8 +179,8 @@ Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( const LogicalBuffer& TuplePointsToAnalysis::NewLogicalBuffer( HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); - logical_buffers_.push_back(MakeUnique( - instruction, index, next_buffer_id_, colorer_(instruction, index))); + logical_buffers_.push_back( + MakeUnique(instruction, index, next_buffer_id_)); ++next_buffer_id_; return *logical_buffers_.back(); } @@ -243,12 +238,11 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( return Status::OK(); } -Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) { // A kCopy instruction performs a shallow copy of the operand. 
The top-level // buffer (index={}) is newly created, but all other buffers (in the case of a // tuple shape) come from the operand - PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand); + PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0)); points_to_set.mutable_element(/*index=*/{})->clear(); points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}), /*index=*/{}); @@ -343,9 +337,11 @@ const PointsToSet& TuplePointsToAnalysis::GetPointsToSet( PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( const HloInstruction* instruction) { - CHECK_EQ(0, points_to_.count(instruction)); - points_to_[instruction] = MakeUnique(instruction->shape()); - return *FindOrDie(points_to_, instruction); + auto set = MakeUnique(&instruction->shape()); + auto res = points_to_.emplace(instruction, std::move(set)); + CHECK(res.second) << "instruction should not have been present in the map."; + // Return *set using the iterator returned by emplace. + return *res.first->second; } bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( @@ -458,6 +454,9 @@ string TuplePointsToAnalysis::ToString() const { string output = tensorflow::strings::Printf( "TuplePointsToSet for module %s:\n", module_->name().c_str()); for (const auto& computation : module_->computations()) { + if (computation->IsFusionComputation()) { + continue; + } const char* entry = computation.get() == module_->entry_computation() ? "entry " : ""; tensorflow::strings::StrAppend(&output, entry, "computation ", diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 4d7fc7cbc9e5ba2ac87dc6fd10691ce308b827f6..099713d671dec21019d9fb3af767b81603570999 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -48,7 +48,10 @@ namespace xla { // the corresponding buffer. 
class PointsToSet : public ShapeTree> { public: - explicit PointsToSet(const Shape& shape) + // Construct our ShapeTree with a pointer rather than a reference to a Shape + // because this is very hot code, and copying (and then destroying) all these + // Shapes is slow. + explicit PointsToSet(const Shape* shape) : ShapeTree>(shape), tuple_sources_(shape) {} @@ -142,15 +145,7 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); // the potential sources of each buffer in each instruction's output. class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { public: - using Colorer = std::function; - - // Runs points-to analysis on 'module' with the provided buffer color - // assigner. - static StatusOr> Run( - const HloModule* module, Colorer colorer); - - // Runs points-to analysis on 'module' with the default color assigner. + // Runs points-to analysis on 'module'. static StatusOr> Run( const HloModule* module); @@ -208,23 +203,15 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) override; string ToString() const; - static Colorer DefaultColorer() { - return [](const HloInstruction* instruction, const ShapeIndex& index) { - return LogicalBuffer::Color(0); - }; - } - private: - explicit TuplePointsToAnalysis(const HloModule* module, - Colorer colorer = DefaultColorer()) - : module_(module), colorer_(colorer) {} + explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {} // Perform the analysis. Should be called immediately after constructing the // object and before calling GetPointsToSet. 
@@ -283,9 +270,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // The ID of the next logical buffer created. LogicalBuffer::Id next_buffer_id_ = 0; - // Used to color the created logical buffers. - Colorer colorer_; - TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis); }; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 9909c11929d4b2ecf632ab644981a039446bdfc8..cd79e63cafcfecce71cf3380aba9e409da0e72c8 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase { TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { // tuple. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant, constant, constant})); @@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { // the same. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto copy = builder.AddInstruction( @@ -318,16 +318,16 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // set containing the union of both sides. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -356,7 +356,7 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, param0, param1)); auto copy = builder.AddInstruction( @@ -396,16 +396,16 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { // Select from two identical tuples. The result should not be ambiguous. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -427,9 +427,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { // the right values. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto inner_tuple2 = builder.AddInstruction( @@ -441,7 +441,7 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -474,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { // have the operand 
of the bitcast in its points-to set. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( constant2->shape(), HloOpcode::kBitcast, constant2)); auto tuple = @@ -510,10 +510,9 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -533,9 +532,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { // times. Verify buffer alias sets. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple = builder.AddInstruction( @@ -574,7 +573,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto tuple_element1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); auto ones = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) auto update = builder.AddInstruction(HloInstruction::CreateBinary( update_shape, HloOpcode::kAdd, tuple_element1, ones)); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4aba8875161c9a2d12668d57ea55ded066d38da0..3ab780e7d0b5f5c0af482d5d452d9a97641e1b54 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -49,6 +48,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kAbs; case UNOP_CEIL: return HloOpcode::kCeil; + case UNOP_COS: + return HloOpcode::kCos; case UNOP_EXP: return HloOpcode::kExp; case UNOP_FLOOR: @@ -63,6 +64,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kNegate; case UNOP_SIGN: return HloOpcode::kSign; + case UNOP_SIN: + return HloOpcode::kSin; case UNOP_SORT: return HloOpcode::kSort; case UNOP_TANH: @@ -465,6 +468,90 @@ StatusOr UserComputation::AddReduceInstruction( return handle; } +StatusOr +UserComputation::AddBatchNormTrainingInstruction( + const BatchNormTrainingRequest& batch_norm_training_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(batch_norm_training_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* scale, + LookUpRequest(batch_norm_training_request.scale())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* offset, + LookUpRequest(batch_norm_training_request.offset())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferBatchNormTrainingShape( + operand->output_shape(), scale->output_shape(), + offset->output_shape(), batch_norm_training_request.feature_index())); + + *request.mutable_output_shape() = inferred_shape; + + *request.mutable_output_handle() = handle; + + *request.mutable_request()->mutable_batch_norm_training_request() = + batch_norm_training_request; + + VLOG(1) << "AddBatchNormTrainingInstruction (" << 
GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << batch_norm_training_request.ShortDebugString(); + + return handle; +} + +StatusOr UserComputation::AddBatchNormGradInstruction( + const BatchNormGradRequest& batch_norm_grad_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(batch_norm_grad_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* scale, + LookUpRequest(batch_norm_grad_request.scale())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* mean, + LookUpRequest(batch_norm_grad_request.mean())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* variance, + LookUpRequest(batch_norm_grad_request.variance())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, + LookUpRequest(batch_norm_grad_request.grad_output())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferBatchNormGradShape( + operand->output_shape(), scale->output_shape(), mean->output_shape(), + variance->output_shape(), grad_output->output_shape(), + batch_norm_grad_request.feature_index())); + + *request.mutable_output_shape() = inferred_shape; + + *request.mutable_output_handle() = handle; + + *request.mutable_request()->mutable_batch_norm_grad_request() = + batch_norm_grad_request; + + VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << batch_norm_grad_request.ShortDebugString(); + + return handle; +} + StatusOr UserComputation::AddReduceWindowInstruction( const ReduceWindowRequest& reduce_window_request, const UserComputation& to_apply_computation) { @@ -841,6 +928,34 @@ StatusOr UserComputation::AddConvertInstruction( return handle; } +StatusOr UserComputation::AddReducePrecisionInstruction( + const 
ReducePrecisionRequest& reduce_precision_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(reduce_precision_request.operand())); + + TF_ASSIGN_OR_RETURN( + Shape new_shape, + ShapeInference::InferReducePrecisionShape( + operand->output_shape(), reduce_precision_request.exponent_bits(), + reduce_precision_request.mantissa_bits())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_reduce_precision_request() = + reduce_precision_request; + + VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reduce_precision_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddConvolveInstruction( const ConvolveRequest& convolve_request) { tensorflow::mutex_lock lock(mutex_); @@ -897,9 +1012,6 @@ StatusOr UserComputation::AddInfeedInstruction( tensorflow::mutex_lock lock(mutex_); const Shape& shape = infeed_request.shape(); - if (ShapeUtil::IsNestedTuple(shape)) { - return InvalidArgument("Infeed does not support nested tuple shapes"); - } if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } @@ -923,9 +1035,6 @@ Status UserComputation::AddOutfeedInstruction( tensorflow::mutex_lock lock(mutex_); const Shape& shape = outfeed_request.shape(); - if (ShapeUtil::IsNestedTuple(shape)) { - return InvalidArgument("Outfeed does not support nested tuple shapes"); - } if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Outfeed must have a layout"); } @@ -1556,6 +1665,36 @@ void ConstantVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kBatchNormTrainingRequest: { + const 
BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + ConstantVisitor(session_computation, + batch_norm_training_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, batch_norm_training_request.scale(), + visited, is_constant); + ConstantVisitor(session_computation, batch_norm_training_request.offset(), + visited, is_constant); + break; + } + + case OpRequest::kBatchNormGradRequest: { + const BatchNormGradRequest& batch_norm_grad_request = + request.request().batch_norm_grad_request(); + ConstantVisitor(session_computation, batch_norm_grad_request.operand(), + visited, is_constant); + ConstantVisitor(session_computation, batch_norm_grad_request.scale(), + visited, is_constant); + ConstantVisitor(session_computation, batch_norm_grad_request.mean(), + visited, is_constant); + ConstantVisitor(session_computation, batch_norm_grad_request.variance(), + visited, is_constant); + ConstantVisitor(session_computation, + batch_norm_grad_request.grad_output(), visited, + is_constant); + break; + } + case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); @@ -1824,7 +1963,6 @@ Status UserComputation::CheckParametersAreContiguous( } } - auto program_shape = MakeUnique(); for (int64 i = 0; i < parameter_requests.size(); ++i) { auto it = parameter_requests.find(i); if (it == parameter_requests.end()) { @@ -1850,26 +1988,31 @@ class ComputationLowerer { const SessionComputation& session_computation, VersionedComputationHandle::Version version, UserComputation::HloComputationResolver hlo_resolver, + const DebugOptions& debug_options, bool include_unreachable_instructions) { ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver)); - return lowerer.Lower(include_unreachable_instructions); + std::move(hlo_resolver), debug_options, + include_unreachable_instructions); + return lowerer.Lower(); } 
private: ComputationLowerer(const string& computation_name, const SessionComputation& session_computation, VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver) + UserComputation::HloComputationResolver hlo_resolver, + const DebugOptions& debug_options, + bool include_unreachable_instructions) : hlo_builder_(computation_name), session_computation_(session_computation), version_(version), - hlo_resolver_(std::move(hlo_resolver)) {} + hlo_resolver_(std::move(hlo_resolver)), + debug_options_(debug_options), + include_unreachable_instructions_(include_unreachable_instructions) {} // Build an HLO computation from the SessionComputation at the given // version. - StatusOr> Lower( - bool include_unreachable_instructions); + StatusOr> Lower(); private: // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. @@ -1899,6 +2042,8 @@ class ComputationLowerer { const SessionComputation& session_computation_; const VersionedComputationHandle::Version version_; const UserComputation::HloComputationResolver hlo_resolver_; + const DebugOptions& debug_options_; + const bool include_unreachable_instructions_; }; // Calls 'apply' on each operand of 'request'. 
@@ -1964,6 +2109,28 @@ static void ForEachOperand( break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + + apply(batch_norm_training_request.operand()); + apply(batch_norm_training_request.scale()); + apply(batch_norm_training_request.offset()); + break; + } + + case OpRequest::kBatchNormGradRequest: { + const BatchNormGradRequest& batch_norm_grad_request = + request.request().batch_norm_grad_request(); + + apply(batch_norm_grad_request.operand()); + apply(batch_norm_grad_request.scale()); + apply(batch_norm_grad_request.mean()); + apply(batch_norm_grad_request.variance()); + apply(batch_norm_grad_request.grad_output()); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); @@ -2117,6 +2284,13 @@ static void ForEachOperand( break; } + case OpRequest::kReducePrecisionRequest: { + const ReducePrecisionRequest& reduce_precision_request = + request.request().reduce_precision_request(); + apply(reduce_precision_request.operand()); + break; + } + case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); apply(trace_request.operand()); @@ -2175,8 +2349,7 @@ void ComputationLowerer::TraversePostorder( } } -StatusOr> ComputationLowerer::Lower( - bool include_unreachable_instructions) { +StatusOr> ComputationLowerer::Lower() { // Map from ComputationDataHandle to HLO instruction. Serves as a record of // which operations have been visited as well as a cache for looking up // ComputationDataHandles as HloInstructions. 
@@ -2192,7 +2365,7 @@ StatusOr> ComputationLowerer::Lower( HloInstruction* hlo_root = instructions.at(root_request->output_handle().handle()); - if (include_unreachable_instructions) { + if (include_unreachable_instructions_) { // Iterate through all computation data handles, and visit any unvisited // operations. for (int64 request_num = 1; request_num <= version_; ++request_num) { @@ -2276,7 +2449,7 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(Literal(constant_request.literal())))); + Literal(constant_request.literal()).CloneToUnique())); break; } @@ -2457,6 +2630,44 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + HloInstruction* operand = + lookup_instruction(batch_norm_training_request.operand()); + HloInstruction* scale = + lookup_instruction(batch_norm_training_request.scale()); + HloInstruction* offset = + lookup_instruction(batch_norm_training_request.offset()); + + hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( + request.output_shape(), operand, scale, offset, + batch_norm_training_request.epsilon(), + batch_norm_training_request.feature_index())); + break; + } + + case OpRequest::kBatchNormGradRequest: { + const BatchNormGradRequest& batch_norm_grad_request = + request.request().batch_norm_grad_request(); + + HloInstruction* operand = + lookup_instruction(batch_norm_grad_request.operand()); + HloInstruction* scale = + lookup_instruction(batch_norm_grad_request.scale()); + HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); + HloInstruction* variance = + lookup_instruction(batch_norm_grad_request.variance()); + HloInstruction* grad_output = + 
lookup_instruction(batch_norm_grad_request.grad_output()); + + hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( + request.output_shape(), operand, scale, mean, variance, grad_output, + batch_norm_grad_request.epsilon(), + batch_norm_grad_request.feature_index())); + break; + } + case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); @@ -2670,8 +2881,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = @@ -2688,6 +2898,18 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kReducePrecisionRequest: { + const ReducePrecisionRequest& reduce_precision_request = + request.request().reduce_precision_request(); + HloInstruction* operand = + lookup_instruction(reduce_precision_request.operand()); + auto exponent_bits = reduce_precision_request.exponent_bits(); + auto mantissa_bits = reduce_precision_request.mantissa_bits(); + hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( + request.output_shape(), operand, exponent_bits, mantissa_bits)); + break; + } + case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); HloInstruction* operand = lookup_instruction(trace_request.operand()); @@ -2718,7 +2940,7 @@ void ComputationLowerer::Visit( StatusOr> UserComputation::BuildHloComputation( VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, + HloComputationResolver hlo_resolver, const DebugOptions& debug_options, bool include_unreachable_instructions) const { tensorflow::mutex_lock 
lock(mutex_); @@ -2730,7 +2952,7 @@ StatusOr> UserComputation::BuildHloComputation( std::unique_ptr hlo_computation, ComputationLowerer::Lower( tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), + session_computation_, version, std::move(hlo_resolver), debug_options, include_unreachable_instructions)); XLA_VLOG_LINES(2, hlo_computation->ToString()); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index fb5425ae61ab1edcd00aac493c9e2ac3c430cb72..36b1d34e05d7ef4d9d6b5d0f76822b6813d117e8 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -84,6 +85,14 @@ class UserComputation { StatusOr AddUnaryInstruction( const UnaryOpRequest& unary_request); + // Enqueues a batch norm training instruction onto this user computation. + StatusOr AddBatchNormTrainingInstruction( + const BatchNormTrainingRequest& batch_norm_training_request); + + // Enqueues a batch norm grad instruction onto this user computation. + StatusOr AddBatchNormGradInstruction( + const BatchNormGradRequest& batch_norm_grad_request); + // Enqueues a binary instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. StatusOr AddBinaryInstruction( @@ -112,6 +121,10 @@ class UserComputation { const MapRequest& map_request, const UserComputation& to_apply_computation); + // Enqueues a reduce-precision instruction onto this user computation. 
+ StatusOr AddReducePrecisionInstruction( + const ReducePrecisionRequest& reduce_precision_request); + // Enqueues a convolution instruction onto this user computation. StatusOr AddConvolveInstruction( const ConvolveRequest& convolve_request); @@ -256,7 +269,7 @@ class UserComputation { std::function; StatusOr> BuildHloComputation( VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, + HloComputationResolver hlo_resolver, const DebugOptions& debug_options, bool include_unreachable_instructions = true) const; // Return a vector containing the embedded computations used by this diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ea691201263e4935afbc29bcb8624a73c6715f83..07739f241aa01eacf83630c72aec6199b66b49d4 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -50,16 +50,16 @@ TEST_F(UserComputationTest, SimpleComputation) { ConstantRequest constant_request; *constant_request.mutable_literal() = - LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); + Literal::CreateR1({123.0f, 42.0f})->ToProto(); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, + computation.AddConstantInstruction(constant_request)); ParameterRequest param_request; *param_request.mutable_shape() = kScalarShape; param_request.set_parameter(0); param_request.set_name("param0"); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, + computation.AddParameterInstruction(param_request)); OpMetadata metadata; metadata.set_op_name("meta"); TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); @@ -81,7 +81,7 @@ TEST_F(UserComputationTest, SimpleComputation) { // Program shape should have a single scalar parameter and scalar // result. The outfeed instruction should not affect the program shape. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::shared_ptr program_shape, computation.ComputeProgramShape(latest_version.version)); ASSERT_EQ(1, program_shape->parameters_size()); @@ -90,9 +90,10 @@ TEST_F(UserComputationTest, SimpleComputation) { EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); // Build the HLO computation. 
- TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // There should be one HloInstruction per UserComputation operation. EXPECT_EQ(3, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the @@ -107,7 +108,7 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.GetVersionedHandleAtOperation(param_handle); // Program shape should have a single scalar parameter, and scalar result. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::shared_ptr program_shape, computation.ComputeProgramShape(version_at_param.version)); ASSERT_EQ(1, program_shape->parameters_size()); @@ -117,9 +118,10 @@ TEST_F(UserComputationTest, SimpleComputation) { // There should be two instructions, one for the constant and one for the // parameter. The outfeed instruction should not be included. - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr hlo_computation, - computation.BuildHloComputation( - version_at_param.version, hlo_resolver)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_computation, + computation.BuildHloComputation(version_at_param.version, hlo_resolver, + DebugOptions())); EXPECT_EQ(2, hlo_computation->instruction_count()); EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); } @@ -130,10 +132,11 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.GetVersionedHandle(); // Build the HLO computation. 
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, - /*include_unreachable_instructions=*/false)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_computation, + computation.BuildHloComputation( + latest_version.version, hlo_resolver, DebugOptions(), + /*include_unreachable_instructions=*/false)); // There is only one reachable instruction, the parameter. EXPECT_EQ(1, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the @@ -145,8 +148,8 @@ TEST_F(UserComputationTest, SimpleComputation) { } TEST_F(UserComputationTest, EliminateScalarBroadcast) { - if (!legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (!legacy_flags::GetDebugOptionsFromFlags() + .xla_eliminate_hlo_implicit_broadcast()) { return; } @@ -161,14 +164,14 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { ConstantRequest a_request; *a_request.mutable_literal() = - LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); + Literal::CreateR1({123.0f, 42.0f})->ToProto(); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddConstantInstruction(a_request)); ConstantRequest b_request; - *b_request.mutable_literal() = LiteralUtil::CreateR0(1.0f)->ToProto(); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); + *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddConstantInstruction(b_request)); BinaryOpRequest add; add.set_binop(BINOP_ADD); @@ -182,9 +185,10 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { VersionedComputationHandle latest_version = computation.GetVersionedHandle(); // Build the HLO computation. 
- TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // The binary operation has implicit scalar broadcast, should be converted // to an explicit broadcast intruction and a binary instruction. EXPECT_EQ(4, hlo_computation->instruction_count()); @@ -196,8 +200,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { } TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - if (!legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (!legacy_flags::GetDebugOptionsFromFlags() + .xla_eliminate_hlo_implicit_broadcast()) { return; } @@ -214,15 +218,15 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); a_request.set_name("a"); a_request.set_parameter(0); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddParameterInstruction(a_request)); ParameterRequest b_request; *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); b_request.set_name("b"); b_request.set_parameter(1); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddParameterInstruction(b_request)); BinaryOpRequest add; add.set_binop(BINOP_ADD); @@ -238,9 +242,10 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { VersionedComputationHandle latest_version = computation.GetVersionedHandle(); // Build the HLO computation. 
- TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // The binary operation has in-dim broadcast and degenerate broadcast, should // first do the in-dim broadcast then convert the degnerate broadcast into a @@ -266,7 +271,7 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendUserComputationFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index cc456df4fce5c78162c41ed36f6c69c0f5ab459b..81cdbf5117f2d16e5a871849a7875b1746baf42a 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -81,19 +82,56 @@ struct ShapeTreeNode { // Like the Shape data structure, this is a tree and tuple elements cannot be // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T // object. +// +// Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes +// it's helpful not to copy a Shape just to make a ShapeTree. In these cases, +// you can pass a Shape* instead of a Shape& to the ShapeTree constructor. 
It's +// then up to you to ensure that the pointed-to Shape doesn't die or mutate +// before its ShapeTree goes away. template class ShapeTree { public: // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} + // Create ShapeTree with the given shape, and default-constructed T values for // all nodes. - explicit ShapeTree(const Shape& shape); + // + // The version that takes a pointer may be cheaper because it doesn't require + // any Shape copies, but then it's up to you to ensure that the pointer stays + // alive longer than this ShapeTree. + explicit ShapeTree(Shape shape); + explicit ShapeTree(const Shape* shape); + // Create ShapeTree with the given shape, and init_value for all nodes. - ShapeTree(const Shape& shape, const T& init_value); + ShapeTree(Shape shape, const T& init_value); + ShapeTree(const Shape* shape, const T& init_value); + + ShapeTree(const ShapeTree& other) + : root_(other.root_), shape_storage_(other.shape_storage_) { + // Fix up internal pointer if necessary. + if (shape_storage_) { + CHECK_EQ(other.shape_, &*other.shape_storage_); + shape_ = &*shape_storage_; + } else { + shape_ = other.shape_; + } + } - ShapeTree(const ShapeTree& other) = default; - ShapeTree& operator=(const ShapeTree& other) = default; + ShapeTree& operator=(const ShapeTree& other) { + root_ = other.root_; + shape_storage_ = other.shape_storage_; + + // Fix up internal pointer if necessary. + if (shape_storage_) { + CHECK_EQ(other.shape_, &*other.shape_storage_); + shape_ = &*shape_storage_; + } else { + shape_ = other.shape_; + } + + return *this; + } // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). @@ -101,7 +139,7 @@ class ShapeTree { T* mutable_element(const ShapeIndex& index); // Return the shape represented with this ShapeTree. 
- const Shape& shape() const { return shape_; } + const Shape& shape() const { return *shape_; } // Returns true if the node at the given index is a leaf node (an array // shape). @@ -112,27 +150,27 @@ class ShapeTree { // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // + // Fn : A callable of type void(const ShapeIndex& index, const T& data) + // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. // data : The data value at this elemnt. - using VisitorFunction = - std::function; - void ForEachElement(const VisitorFunction& func) const; - - using MutableVisitorFunction = - std::function; - void ForEachMutableElement(const MutableVisitorFunction& func); + template + void ForEachElement(const Fn& func) const; - // Variants of ForEach(Mutable)Element which propagate a Status value from the - // visitor. - using StatusVisitorFunction = - std::function; - Status ForEachElementWithStatus(const StatusVisitorFunction& func) const; + // Like ForEachElement, but the callable has type + // + // void (const ShapeIndex& index, T* data). + // + template + void ForEachMutableElement(const Fn& func); - using MutableStatusVisitorFunction = - std::function; - Status ForEachMutableElementWithStatus( - const MutableStatusVisitorFunction& func); + // Like ForEach(Mutable)Element, but the callable returns a Status instead of + // void. The first non-OK return value is returned by the ForEach* function. + template + Status ForEachElementWithStatus(const Fn& func) const; + template + Status ForEachMutableElementWithStatus(const Fn& func); // Copy the subtree of values from 'other' rooted at ShapeIndex // 'source_base_index' into the subtree of value in this ShapeTree rooted at @@ -161,10 +199,12 @@ class ShapeTree { // Helpers for traversing the shape via ForEachElement. 
The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). - static Status ForEachHelper(const StatusVisitorFunction& func, - const Node& node, ShapeIndex* index); - static Status ForEachMutableHelper(const MutableStatusVisitorFunction& func, - Node* node, ShapeIndex* index); + template + static Status ForEachHelper(const Fn& func, const Node& node, + ShapeIndex* index); + template + static Status ForEachMutableHelper(const Fn& func, Node* node, + ShapeIndex* index); // Return the tree node at the given index. Node* Lookup(const ShapeIndex& index); @@ -173,8 +213,13 @@ class ShapeTree { // The root node, which contains all other nodes. Node root_; - // The XLA shape mirrored in this ShapeTree. - Shape shape_; + // If we own our Shape, this field contains it, and shape_ is a pointer into + // here. Otherwise if we don't own our shape, this is nullopt. + tensorflow::gtl::optional shape_storage_; + + // The XLA shape mirrored in this ShapeTree. This is either a pointer into + // shape_storage_ or the Shape pointer passed to our constructor. + const Shape* shape_; }; template @@ -200,20 +245,34 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { } template -ShapeTree::ShapeTree(const Shape& shape) : root_(), shape_(shape) { +ShapeTree::ShapeTree(Shape shape) + : root_(), shape_storage_(std::move(shape)), shape_(&*shape_storage_) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. 
- LayoutUtil::ClearLayout(&shape_); - InitChildren(shape_, &root_); + LayoutUtil::ClearLayout(&*shape_storage_); + InitChildren(*shape_, &root_); } template -ShapeTree::ShapeTree(const Shape& shape, const T& init_value) - : root_(init_value), shape_(shape) { +ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { + InitChildren(*shape_, &root_); +} + +template +ShapeTree::ShapeTree(Shape shape, const T& init_value) + : root_(init_value), + shape_storage_(std::move(shape)), + shape_(&*shape_storage_) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(&shape_); - InitChildren(shape_, init_value, &root_); + LayoutUtil::ClearLayout(&*shape_storage_); + InitChildren(*shape_, init_value, &root_); +} + +template +ShapeTree::ShapeTree(const Shape* shape, const T& init_value) + : root_(init_value), shape_(shape) { + InitChildren(*shape_, init_value, &root_); } template @@ -245,8 +304,9 @@ const internal::ShapeTreeNode* ShapeTree::Lookup( /* static */ template -Status ShapeTree::ForEachHelper(const StatusVisitorFunction& func, - const Node& node, ShapeIndex* index) { +template +Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, + ShapeIndex* index) { TF_RETURN_IF_ERROR(func(*index, node.data)); for (int64 i = 0; i < node.children.size(); ++i) { index->push_back(i); @@ -258,8 +318,9 @@ Status ShapeTree::ForEachHelper(const StatusVisitorFunction& func, /* static */ template -Status ShapeTree::ForEachMutableHelper( - const MutableStatusVisitorFunction& func, Node* node, ShapeIndex* index) { +template +Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, + ShapeIndex* index) { TF_RETURN_IF_ERROR(func(*index, &node->data)); for (int64 i = 0; i < node->children.size(); ++i) { index->push_back(i); @@ -271,21 +332,22 @@ Status ShapeTree::ForEachMutableHelper( } template -Status ShapeTree::ForEachElementWithStatus( - const 
StatusVisitorFunction& func) const { +template +Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { ShapeIndex index; return ForEachHelper(func, root_, &index); } template -Status ShapeTree::ForEachMutableElementWithStatus( - const MutableStatusVisitorFunction& func) { +template +Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { ShapeIndex index; return ForEachMutableHelper(func, &root_, &index); } template -void ShapeTree::ForEachElement(const VisitorFunction& func) const { +template +void ShapeTree::ForEachElement(const Fn& func) const { ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { @@ -297,7 +359,8 @@ void ShapeTree::ForEachElement(const VisitorFunction& func) const { } template -void ShapeTree::ForEachMutableElement(const MutableVisitorFunction& func) { +template +void ShapeTree::ForEachMutableElement(const Fn& func) { ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index afc3a2b2a34777780ec66d2325011390879fe693..3a5db1b3a651e2d353741c6bf4f6962da4e54ba1 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -365,5 +365,31 @@ TEST_F(ShapeTreeTest, OperatorEquals) { } } +TEST_F(ShapeTreeTest, ConstructWithPointerToShape) { + // Construct a ShapeTree using a pointer to a shape, rather than a reference + // to a shape. This constructor is an optimization to let us avoid + // constructing and destroying temporary shapes when we have many ShapeTrees. 
+ ShapeTree t(&nested_tuple_shape_, 42); + int num_nodes = 0; + t.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { + EXPECT_EQ(42, data); + ++num_nodes; + }); + EXPECT_EQ(10, num_nodes); +} + +TEST_F(ShapeTreeTest, CopyWithPointerToShape) { + ShapeTree source(&nested_tuple_shape_, 0); + ShapeTree dest(source); + EXPECT_EQ(&dest.shape(), &nested_tuple_shape_); +} + +TEST_F(ShapeTreeTest, CopyAssignWithPointerToShape) { + ShapeTree source(&nested_tuple_shape_, 0); + ShapeTree dest; + dest = source; + EXPECT_EQ(&dest.shape(), &nested_tuple_shape_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ee49a9ae5f5ff442284f2c4bd620425f815fb08d..057905a4311edc246eeea55019821e834605ae78 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -105,6 +105,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return equal; } +/* static */ int64 ShapeUtil::Rank(const Shape& shape) { + CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank"; + return shape.dimensions_size(); +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -165,6 +170,17 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); } + +/* static */ Shape ShapeUtil::ShapeWithoutPadding(const Shape& shape) { + Shape result = shape; + ForEachMutableSubshape(&result, [](Shape* subshape, const ShapeIndex& index) { + auto layout = subshape->mutable_layout(); + layout->clear_padding_value(); + layout->clear_padded_dimensions(); + }); + return result; +} + /* static */ void ShapeUtil::PopulateShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, Shape* shape) { @@ -270,7 +286,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool 
compare_layouts) { } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsEmptyTuple(shape) || HasZeroElements(shape); + return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -323,6 +339,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { + CHECK(!IsTuple(shape)); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -534,11 +551,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { - // Tuple shape. - if (Rank(shape) != 0) { - return InvalidArgument("tuples must be rank-0; got rank %lld", - Rank(shape)); - } if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 853be6b4cb81881f3f03dbb119dee533aa27634f..fa34bfc951d58d252b4381e10a01b39698eb9015 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -93,6 +93,7 @@ class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. + // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -144,7 +145,8 @@ class ShapeUtil { static bool Equal(const Shape& lhs, const Shape& rhs); // Returns the rank (number of dimensions) of the given shape. 
- static int64 Rank(const Shape& shape) { return shape.dimensions_size(); } + // Precondition: !IsTuple(shape) + static int64 Rank(const Shape& shape); // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just @@ -220,6 +222,9 @@ class ShapeUtil { // elements with a different shape. static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape); + // Returns a new shape that has all padding values cleared. + static Shape ShapeWithoutPadding(const Shape& shape); + // As MakeShape, but the object to write to is passed in. static void PopulateShape(PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index aa12cda666c4abfbf7ec38f0aa640df3b51ea106..5e5550563d02de99ddefbeb8ee8e1bf98afdcdbf 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -183,15 +183,15 @@ class StatusAdaptorForMacros { .with_log_stack_trace() \ .add_ret_check_failure(#condition) -#define TF_ASSIGN_OR_ASSERT_OK(lhs, rexpr) \ - TF_ASSIGN_OR_ASSERT_OK_IMPL( \ +#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + TF_ASSERT_OK_AND_ASSIGN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ rexpr); -#define TF_ASSIGN_OR_ASSERT_OK_IMPL(statusor, lhs, rexpr) \ +#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ auto statusor = (rexpr); \ ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ - lhs = statusor.ConsumeValueOrDie() + lhs = std::move(statusor.ValueOrDie()) #define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) #define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y diff --git a/tensorflow/compiler/xla/status_macros_test.cc b/tensorflow/compiler/xla/status_macros_test.cc index dead17cdfa1e9f19e0ecfbc071e74e159ae82b5f..4b0740dad72f5d96e5ae153abf9232553ff834c2 100644 --- 
a/tensorflow/compiler/xla/status_macros_test.cc +++ b/tensorflow/compiler/xla/status_macros_test.cc @@ -63,7 +63,7 @@ StatusOr CreateIntUnsuccessfully() { } TEST(StatusMacros, AssignOrAssertOnOK) { - TF_ASSIGN_OR_ASSERT_OK(int result, CreateIntSuccessfully()); + TF_ASSERT_OK_AND_ASSIGN(int result, CreateIntSuccessfully()); EXPECT_EQ(42, result); } diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/compiler/xla/statusor.cc index 36f08fc99f45a7c82f086d04fa60014343d574da..72ab67ff810e0ec384a22da092363cc7446435bb 100644 --- a/tensorflow/compiler/xla/statusor.cc +++ b/tensorflow/compiler/xla/statusor.cc @@ -19,28 +19,20 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { -namespace internal { +namespace internal_statusor { -Status StatusOrHelper::HandleInvalidStatusCtorArg() { +void Helper::HandleInvalidStatusCtorArg(Status* status) { const char* kMessage = - "Status::OK is not a valid constructor argument to StatusOr"; + "An OK status is not a valid constructor argument to StatusOr"; LOG(ERROR) << kMessage; - // In optimized builds, we will fall back to tensorflow::error::INTERNAL. - return Status(tensorflow::error::INTERNAL, kMessage); + // Fall back to tensorflow::error::INTERNAL. + *status = ::tensorflow::errors::Internal(kMessage); } -Status StatusOrHelper::HandleNullObjectCtorArg() { - const char* kMessage = - "NULL is not a valid constructor argument to StatusOr"; - LOG(ERROR) << kMessage; - // In optimized builds, we will fall back to tensorflow::error::INTERNAL. 
- return Status(tensorflow::error::INTERNAL, kMessage); -} - -void StatusOrHelper::Crash(const Status& status) { +void Helper::Crash(const Status& status) { LOG(FATAL) << "Attempting to fetch value instead of handling error " << status; } -} // namespace internal +} // namespace internal_statusor } // namespace xla diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index d8cd736238c19cc00d0302daa54fc7417740001a..92bcfa0f44d524c1652ec3d2493a3ebb48b95423 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -72,216 +72,233 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor_internals.h" #include "tensorflow/core/platform/macros.h" namespace xla { #if defined(__clang__) // Only clang supports warn_unused_result as a type annotation. -template +template class TF_MUST_USE_RESULT StatusOr; #endif -template ::value> -class StatusOr { - template +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::TraitsBase< + std::is_copy_constructible::value, + std::is_move_constructible::value> { + template friend class StatusOr; + typedef internal_statusor::StatusOrData Base; + public: typedef T element_type; - // Construct a new StatusOr with Status::UNKNOWN status - StatusOr(); + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // StatusOr> will be initialized with an empty vector, + // instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr will be copy constructuble/assignable if T is copy + // constructible. + StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr will be move constructuble/assignable if T is move + // constructible. 
+ StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Conversion copy/move constructor, T must be convertible from U. + // TODO(b/62186717): These should not participate in overload resolution if U + // is not convertible to T. + template + StatusOr(const StatusOr& other); + template + StatusOr(StatusOr&& other); - // Construct a new StatusOr with the given non-ok status. After calling - // this constructor, calls to ValueOrDie() will CHECK-fail. - // - // NOTE: Not explicit - we want to use StatusOr as a return - // value, so it is convenient and sensible to be able to do 'return - // Status()' when the return type is StatusOr. - // - // REQUIRES: status != Status::OK. This requirement is DCHECKed. - // In optimized builds, passing Status::OK here will have the effect - // of passing tensorflow::error::INTERNAL as a fallback. - StatusOr(Status status); // NOLINT + // Conversion copy/move assignment operator, T must be convertible from U. + template + StatusOr& operator=(const StatusOr& other); + template + StatusOr& operator=(StatusOr&& other); - // Construct a new StatusOr with the given value. If T is a plain pointer, - // value must not be NULL. After calling this constructor, calls to - // ValueOrDie() will succeed, and calls to status() will return OK. + // Constructs a new StatusOr with the given value. After calling this + // constructor, calls to ValueOrDie() will succeed, and calls to status() will + // return OK. // // NOTE: Not explicit - we want to use StatusOr as a return type // so it is convenient and sensible to be able to do 'return T()' // when the return type is StatusOr. // - // REQUIRES: if T is a plain pointer, value != NULL. This requirement is - // DCHECKed. In optimized builds, passing a NULL pointer here will have - // the effect of passing tensorflow::error::INTERNAL as a fallback. - StatusOr(const T& value); // NOLINT - - // Copy constructor. 
- StatusOr(const StatusOr& other) = default; - - // Conversion copy constructor, T must be copy constructible from U - template - StatusOr(const StatusOr& other); - - // Assignment operator. - StatusOr& operator=(const StatusOr& other) = default; + // REQUIRES: T is copy constructible. + StatusOr(const T& value); - // Conversion assignment operator, T must be assignable from U - template - StatusOr& operator=(const StatusOr& other); + // Constructs a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing Status::OK() here will have the effect + // of passing tensorflow::error::INTERNAL as a fallback. + StatusOr(const Status& status); + StatusOr& operator=(const Status& status); - // Move constructor and move-assignment operator. - StatusOr(StatusOr&& other) = default; - StatusOr& operator=(StatusOr&& other) = default; + // TODO(b/62186997): Add operator=(T) overloads. - // Rvalue-reference overloads of the other constructors and assignment - // operators, to support move-only types and avoid unnecessary copying. + // Similar to the `const T&` overload. // - // Implementation note: we could avoid all these rvalue-reference overloads - // if the existing lvalue-reference overloads took their arguments by value - // instead. I think this would also let us omit the conversion assignment - // operator altogether, since we'd get the same functionality for free - // from the implicit conversion constructor and ordinary assignment. - // However, this could result in extra copy operations unless we use - // std::move to avoid them, and we can't use std::move because this code - // needs to be portable to C++03. 
- StatusOr(T&& value); // NOLINT - template - StatusOr(StatusOr&& other); + // REQUIRES: T is move constructible. + StatusOr(T&& value); - // Returns a reference to our status. If this contains a T, then - // returns Status::OK. - const Status& status() const { return status_; } + // RValue versions of the operations declared above. + StatusOr(Status&& status); + StatusOr& operator=(Status&& status); // Returns this->status().ok() - bool ok() const { return status_.ok(); } + bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns Status::OK(). + const Status& status() const &; + Status status() &&; // Returns a reference to our current value, or CHECK-fails if !this->ok(). - const T& ValueOrDie() const; - T& ValueOrDie(); + // + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.ValueOrDie(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.ValueOrDie(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).ValueOrDie(); + // + // The std::move on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. + // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 + // See go/ref-qualifiers for more details on such overloads. + const T& ValueOrDie() const &; + T& ValueOrDie() &; + const T&& ValueOrDie() const &&; + T&& ValueOrDie() &&; - // Moves our current value out of this object and returns it, or CHECK-fails - // if !this->ok(). - // Use of this method is discouraged; prefer std::move(statusor.ValueOrDie()) - // instead. 
T ConsumeValueOrDie() { return std::move(ValueOrDie()); } - private: - Status status_; - T value_; -}; - -// Partial specialization for when T is not copy-constructible. This uses all -// methods from the core implementation, but removes copy assignment and copy -// construction. -template -class StatusOr : public StatusOr { - public: - // Remove copies. - StatusOr(const StatusOr& other) = delete; - StatusOr& operator=(const StatusOr& other) = delete; - template - StatusOr(const StatusOr& other) = delete; - StatusOr(const T& value) = delete; - - // Use the superclass version for other constructors and operators. - StatusOr() = default; - StatusOr(StatusOr&& other) = default; - StatusOr& operator=(StatusOr&& other) = default; - StatusOr(T&& value) // NOLINT - : StatusOr::StatusOr(std::move(value)) {} - StatusOr(Status status) // NOLINT - : StatusOr::StatusOr(std::move(status)) {} - template - StatusOr(StatusOr&& other) // NOLINT - : StatusOr::StatusOr(std::move(other)) {} + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; }; //////////////////////////////////////////////////////////////////////////////// // Implementation details for StatusOr -namespace internal { +template +StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} -class StatusOrHelper { - public: - // Move type-agnostic error handling to the .cc. - static Status HandleInvalidStatusCtorArg(); - static Status HandleNullObjectCtorArg(); - static void Crash(const Status& status); - - // Customized behavior for StatusOr vs. StatusOr - template - struct Specialize; -}; +template +StatusOr::StatusOr(const T& value) : Base(value) {} template -struct StatusOrHelper::Specialize { - // For non-pointer T, a reference can never be NULL. 
- static inline bool IsValueNull(const T& t) { return false; } -}; +StatusOr::StatusOr(const Status& status) : Base(status) {} template -struct StatusOrHelper::Specialize { - static inline bool IsValueNull(const T* t) { return t == NULL; } -}; +StatusOr& StatusOr::operator=(const Status& status) { + this->Assign(status); + return *this; +} -} // namespace internal +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} -template -inline StatusOr::StatusOr() - : status_(tensorflow::error::UNKNOWN, "") {} +template +StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} -template -inline StatusOr::StatusOr(Status status) - : status_(std::move(status)) { - if (status_.ok()) { - status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg(); - } +template +StatusOr& StatusOr::operator=(Status&& status) { + this->Assign(std::move(status)); + return *this; } -template -inline StatusOr::StatusOr(const T& value) - : value_(value) { - if (internal::StatusOrHelper::Specialize::IsValueNull(value)) { - status_ = internal::StatusOrHelper::HandleNullObjectCtorArg(); - } -} +template +template +inline StatusOr::StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} -template +template template -inline StatusOr::StatusOr(const StatusOr& other) - : status_(other.status_), value_(other.value_) {} - -template -inline StatusOr::StatusOr(T&& value) - : value_(std::move(value)) { - if (internal::StatusOrHelper::Specialize::IsValueNull(value_)) { - status_ = internal::StatusOrHelper::HandleNullObjectCtorArg(); - } +inline StatusOr& StatusOr::operator=(const StatusOr& other) { + if (other.ok()) + this->Assign(other.ValueOrDie()); + else + this->Assign(other.status()); + return *this; } -template +template template -inline StatusOr::StatusOr(StatusOr&& other) - : status_(std::move(other.status_)), value_(std::move(other.value_)) {} +inline StatusOr::StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} -template -inline const T& 
StatusOr::ValueOrDie() const { - if (!ok()) { - internal::StatusOrHelper::Crash(status()); +template +template +inline StatusOr& StatusOr::operator=(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); } - return value_; + return *this; } -template -inline T& StatusOr::ValueOrDie() { - if (!status_.ok()) { - internal::StatusOrHelper::Crash(status()); - } - return value_; +template +const Status& StatusOr::status() const & { + return this->status_; +} +template +Status StatusOr::status() && { + return ok() ? Status::OK() : std::move(this->status_); +} + +template +const T& StatusOr::ValueOrDie() const & { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::ValueOrDie() const && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +void StatusOr::IgnoreError() const { + // no-op } } // namespace xla diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/compiler/xla/statusor_internals.h new file mode 100644 index 0000000000000000000000000000000000000000..a2fda5bb3c6f11c20fc45c57885b1ce7523db81d --- /dev/null +++ b/tensorflow/compiler/xla/statusor_internals.h @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace internal_statusor { + +class Helper { + public: + // Move type-agnostic error handling to the .cc. + static void HandleInvalidStatusCtorArg(Status*); + TF_ATTRIBUTE_NORETURN static void Crash(const Status& status); +}; + +// Construct an instance of T in `p` through placement new, passing Args... to +// the constructor. +// This abstraction is here mostly for the gcc performance fix. +template +void PlacementNew(void* p, Args&&... args) { +#if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. + if (p == nullptr) __builtin_unreachable(); +#endif + new (p) T(std::forward(args)...); +} + +// Helper base class to hold the data and all operations. +// We move all this to a base class to allow mixing with the appropriate +// TraitsBase specialization. 
+template +class StatusOrData { + template + friend class StatusOrData; + + public: + StatusOrData() = delete; + + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + StatusOrData(StatusOrData&& other) noexcept { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + template + StatusOrData(StatusOrData&& other) { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } + explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } + + explicit StatusOrData(const Status& status) : status_(status) { + EnsureNotOk(); + } + explicit StatusOrData(Status&& status) : status_(std::move(status)) { + EnsureNotOk(); + } + + StatusOrData& operator=(const StatusOrData& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(other.data_); + else + Assign(other.status_); + return *this; + } + + StatusOrData& operator=(StatusOrData&& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else + Assign(std::move(other.status_)); + return *this; + } + + ~StatusOrData() { + if (ok()) { + status_.~Status(); + data_.~T(); + } else { + status_.~Status(); + } + } + + void Assign(const T& value) { + if (ok()) { + data_.~T(); + MakeValue(value); + } else { + MakeValue(value); + status_ = Status::OK(); + } + } + + void Assign(T&& value) { + if (ok()) { + data_.~T(); + MakeValue(std::move(value)); + } else { + MakeValue(std::move(value)); + status_ = Status::OK(); + } + } + + void Assign(const Status& status) { + Clear(); + 
status_ = status; + EnsureNotOk(); + } + + void Assign(Status&& status) { + Clear(); + status_ = std::move(status); + EnsureNotOk(); + } + + bool ok() const { return status_.ok(); } + + protected: + // status_ will always be active after the constructor. + // We make it a union to be able to initialize exactly how we need without + // waste. + // Eg. in the copy constructor we use the default constructor of Status in + // the ok() path to avoid an extra Ref call. + union { + Status status_; + }; + + // data_ is active iff status_.ok()==true + struct Dummy {}; + union { + // When T is const, we need some non-const object we can cast to void* for + // the placement new. dummy_ is that object. + Dummy dummy_; + T data_; + }; + + void Clear() { + if (ok()) data_.~T(); + } + + void EnsureOk() const { + if (!ok()) Helper::Crash(status_); + } + + void EnsureNotOk() { + if (ok()) Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (ie. data_) through placement new with the passed + // argument. + template + void MakeValue(Arg&& arg) { + internal_statusor::PlacementNew(&dummy_, std::forward(arg)); + } + + // Construct the status (ie. status_) through placement new with the passed + // argument. + template + void MakeStatus(Args&&... args) { + internal_statusor::PlacementNew(&status_, + std::forward(args)...); + } +}; + +// Helper base class to allow implicitly deleted constructors and assignment +// operations in StatusOr. +// TraitsBase will explicitly delete what it can't support and StatusOr will +// inherit that behavior implicitly. 
+template +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = default; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = default; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = delete; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = delete; +}; + +} // namespace internal_statusor +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f8555113f816d933423bdf38741d18a574ddd9ce..5fa2211ac66177514ac8ecabfa8791e7c8c014a2 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -29,8 +29,6 @@ limitations under the License. namespace xla { namespace { -using tensorflow::Status; - class Base1 { public: virtual ~Base1() {} @@ -59,6 +57,14 @@ class CopyNoAssign { const CopyNoAssign& operator=(const CopyNoAssign&); }; +class NoDefaultConstructor { + public: + explicit NoDefaultConstructor(int foo); +}; + +static_assert(!std::is_default_constructible(), + "Should not be default-constructible."); + StatusOr> ReturnUniquePtr() { // Uses implicit constructor from T&& return std::unique_ptr(new int(0)); @@ -69,6 +75,18 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, TestNoDefaultConstructorInitialization) { + // Explicitly initialize it with an error code. 
+ StatusOr statusor(tensorflow::errors::Cancelled("")); + EXPECT_FALSE(statusor.ok()); + EXPECT_EQ(statusor.status().code(), tensorflow::error::CANCELLED); + + // Default construction of StatusOr initializes it with an UNKNOWN error code. + StatusOr statusor2; + EXPECT_FALSE(statusor2.ok()); + EXPECT_EQ(statusor2.status().code(), tensorflow::error::UNKNOWN); +} + TEST(StatusOr, TestMoveOnlyInitialization) { StatusOr> thing(ReturnUniquePtr()); ASSERT_TRUE(thing.ok()); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 13dd1a30b60a64171425f2a7d872da9bb2ca5380..a94ff9db899c52f1005cdae84ede2209467bcb8f 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -26,6 +26,7 @@ filegroup( load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") @@ -94,11 +95,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", @@ -116,8 +117,12 @@ cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", ], ) @@ -139,6 +144,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -151,7 +157,6 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -171,6 +176,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", @@ -196,12 +202,14 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -213,12 +221,13 @@ xla_test( 
srcs = ["bad_rng_shape_validation_test.cc"], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", @@ -233,12 +242,12 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -255,7 +264,6 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -268,6 +276,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", @@ -275,7 +284,6 @@ xla_test( 
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -291,7 +299,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -307,6 +314,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", @@ -315,7 +323,6 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -339,7 +346,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -356,7 +362,6 @@ xla_test( 
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -371,7 +376,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -388,7 +393,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -409,7 +414,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -422,12 +427,13 @@ xla_test( srcs = ["deallocation_test.cc"], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -441,13 +447,14 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -471,9 +478,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -481,36 +486,29 @@ xla_test( ) xla_test( - name = "dot_operation_test", - srcs = ["dot_operation_test.cc"], + name = "reduce_precision_test", + srcs = ["reduce_precision_test.cc"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", 
+ "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) -# Tests the dot operation in some cases that can be performed via a -# runtime call on some backends - e.g. a runtime call to to Eigen. xla_test( - name = "dot_operation_runtime_test", + name = "dot_operation_test", srcs = ["dot_operation_test.cc"], - backend_args = { - "cpu": ["--xla_cpu_use_eigen"], - "cpu_parallel": ["--xla_cpu_use_eigen"], - }, deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -518,9 +516,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -530,20 +526,11 @@ xla_test( ], ) -# Repeat dot_operation_runtime_test with single-threded eigen. 
+# Tests the dot operation in some cases that can be performed via a +# runtime call on some backends - e.g. a runtime call to Eigen. xla_test( - name = "dot_operation_single_threaded_runtime_test", + name = "dot_operation_runtime_test", srcs = ["dot_operation_test.cc"], - backend_args = { - "cpu": [ - "--xla_cpu_use_eigen", - "--xla_cpu_multi_thread_eigen=false", - ], - "cpu_parallel": [ - "--xla_cpu_use_eigen", - "--xla_cpu_multi_thread_eigen=false", - ], - }, deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -551,9 +538,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -563,17 +548,16 @@ xla_test( ], ) +# Repeat dot_operation_runtime_test with single-threded eigen. 
xla_test( - name = "dot_operation_rowmajor_runtime_test", + name = "dot_operation_single_threaded_runtime_test", srcs = ["dot_operation_test.cc"], backend_args = { "cpu": [ - "--xla_cpu_use_eigen", - "--xla_default_layout=major2minor", + "--xla_cpu_multi_thread_eigen=false", ], "cpu_parallel": [ - "--xla_cpu_use_eigen", - "--xla_default_layout=major2minor", + "--xla_cpu_multi_thread_eigen=false", ], }, deps = [ @@ -583,9 +567,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -605,7 +587,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -624,7 +606,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -650,7 +632,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", 
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -677,7 +659,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -694,12 +676,13 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -710,19 +693,29 @@ xla_test( xla_test( name = "batch_normalization_test", srcs = ["batch_normalization_test.cc"], + shard_count = 40, deps = [ + ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + 
"//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -738,7 +731,7 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -754,7 +747,7 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -768,11 +761,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -799,7 +794,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -816,7 +811,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -842,7 +837,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -850,11 +845,14 @@ xla_test( ], ) -xla_test( - name = "reduce_window_test", - timeout = "long", +# External xla_test targets can add "reduce_window_test_library" to xla_test_library_deps, in order +# to refer to the cc_library compiled with the correct 
backend macros. The following test target +# "reduce_window_test" is an example. +xla_test_library( + name = "reduce_window_test_library", srcs = ["reduce_window_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", @@ -865,7 +863,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -873,6 +871,14 @@ xla_test( ], ) +xla_test( + name = "reduce_window_test", + timeout = "long", + srcs = [], + xla_test_library_deps = [":reduce_window_test_library"], + deps = [], +) + xla_test( name = "select_and_scatter_test", timeout = "long", @@ -889,7 +895,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -906,7 +912,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -921,10 +927,11 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -941,11 +948,12 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -958,7 +966,7 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -976,8 +984,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", ], @@ -995,7 +1002,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1009,7 +1016,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1022,7 +1029,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1044,7 +1051,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1058,11 +1065,12 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1081,13 +1089,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1103,7 +1112,7 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1125,7 +1134,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1142,11 +1151,12 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1161,7 +1171,6 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1184,7 +1193,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1199,7 +1208,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1215,13 +1224,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1240,7 +1250,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1262,7 +1272,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1279,7 +1289,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1298,7 +1308,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", 
"//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1315,8 +1325,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1324,6 +1340,31 @@ xla_test( ], ) +xla_test( + name = "multioutput_fusion_test", + srcs = ["multioutput_fusion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + 
"//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_test( name = "local_client_aot_test", srcs = [ @@ -1333,6 +1374,7 @@ cc_test( linkstatic = 1, deps = [ "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -1347,9 +1389,9 @@ cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:local_service", - "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -1365,7 +1407,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1381,7 +1423,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1407,7 +1449,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", 
"//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1415,6 +1457,15 @@ xla_test( ], ) +xla_test( + name = "deep_graph_test", + srcs = ["deep_graph_test.cc"], + deps = [ + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + ], +) + cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c07f2745fe9e67898148bf0026ac32534eac506c..fb913e200ffa2ea64cb4014fe3d62efafcfb2bfa 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -26,9 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -45,7 +43,7 @@ namespace { class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: - ErrorSpec error_spec_{0.0001}; + ErrorSpec error_spec_{0.0001, 0.0001}; }; class ArrayElementwiseOpTestParamCount @@ -158,13 +156,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + std::unique_ptr a_literal = Literal::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a_constant = 
builder.ConstantR1(a_values); auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + std::unique_ptr b_literal = Literal::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); @@ -804,7 +802,7 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + std::unique_ptr param_literal = Literal::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -826,6 +824,244 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { ComputeAndCompareR1(&b, expected, {param_data.get()}, error_spec_); } +TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Pow(b.Exp(param0), param1); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = std::pow(std::exp(values0[i]), values1[i]); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; + std::vector values1 
= {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Log(b.Pow(param0, param1)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = std::log(std::pow(values0[i], values1[i])); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Mul(b.Exp(param0), b.Exp(param1)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = std::exp(values0[i]) * std::exp(values1[i]); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + 
client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Div(param0, b.Exp(param1)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = values0[i] / std::exp(values1[i]); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + b.Div(b.Div(param0, param1), param2); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = (values0[i] / values1[i]) / values2[i]; + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, 
-0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + b.Div(param0, b.Div(param1, param2)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = values0[i] / (values1[i] / values2[i]); + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + b.Div(param0, b.Pow(param1, param2)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < 
values0.size(); ++i) { + expected[i] = values0[i] / std::pow(values1[i], values2[i]); + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Div4F32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; + std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + + std::unique_ptr literal3 = Literal::CreateR1(values3); + std::unique_ptr data3 = + client_->TransferToServer(*literal3).ConsumeValueOrDie(); + + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + auto param3 = b.Parameter(3, literal3->shape(), "param2"); + b.Div(b.Div(param0, param1), b.Div(param2, param3)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]); + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()}, + error_spec_); +} + TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -1241,12 +1477,12 @@ TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr 
param0_literal = - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); + Literal::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -1263,12 +1499,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -1285,7 +1521,7 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -1297,6 +1533,24 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { {param0_data.get()}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); + auto result = builder.Cos(a); + + ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); + auto result = 
builder.Sin(a); + + ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, + error_spec_); +} + TEST_F(ArrayElementwiseOpTest, TanhF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); @@ -1447,9 +1701,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); + auto expected = Literal::MakeTuple( + {Literal::CreateR2({{true, true}, {true, false}}).get(), + Literal::CreateR2({{true, false}, {false, false}}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -1802,7 +2056,7 @@ TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR4FromArray4D(r4); + std::unique_ptr a_literal = Literal::CreateR4FromArray4D(r4); *a_literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); auto a = builder.ConstantLiteral(*a_literal); @@ -1838,8 +2092,8 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { // broadcast. 
TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { ComputationBuilder builder(client_, TestName()); - auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); - auto y_literal = LiteralUtil::CreateR1({4, 5}); + auto x_literal = Literal::CreateR1({1, 2, 3}); + auto y_literal = Literal::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -1862,8 +2116,6 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index a1ca1de584f8be808d19a43680f7c093d4f94def..67dbc913b42c89bf5a8fb5b91da13a29e5e248f5 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -76,7 +75,6 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index ea58491038c1dfcc8069b3c14833ade554be0d8a..02be0b5ab83c23fda36c5ccc65a598fc8e4a1600 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -70,7 +69,6 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 6a47f1b718a1734de731ec50d7094ac529eca9df..d692a810325eae1ebe50e1ad84caf279d51666f1 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -23,13 +23,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -48,7 +57,7 @@ class BatchNormalizationTest : public ClientLibraryTestBase { {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_); + input_literal_ = *Literal::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -190,13 +199,422 @@ TEST_F(BatchNormalizationTest, SpecComparisonForward) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } +struct BatchNormTestParam { + std::vector bounds; + int64 feature_index; + float random_value_mean; + float random_value_var; +}; + +// Tests to test the fused operation of BatchNorm. 
+class BatchNormTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { +}; + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) { + float epsilon = 0.001; + ComputationBuilder builder(client_, TestName()); + const std::vector& bounds = GetParam().bounds; + Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); + input_array.FillRandom(GetParam().random_value_var, + GetParam().random_value_mean); + + const int64 feature_index = GetParam().feature_index; + const int64 num_elements_per_feature = + Product(bounds) / bounds[feature_index]; + const int64 feature_bound = bounds[feature_index]; + std::vector offset(feature_bound, 1); + std::vector scale(feature_bound, 2); + + auto input_squared = + ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); + std::vector reduce_dims; + for (int64 i = 0; i < bounds.size(); ++i) { + if (i != feature_index) { + reduce_dims.push_back(i); + } + } + + auto sum = + ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + auto sum_squared = + ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + std::vector mean(feature_bound); + + for (int64 i = 0; i < feature_bound; ++i) { + mean[i] = sum[i] / num_elements_per_feature; + } + + std::vector mean_square(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + mean_square[i] = mean[i] * mean[i]; + } + + std::vector square_mean(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + square_mean[i] = sum_squared[i] / num_elements_per_feature; + } + + std::vector var(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + var[i] = square_mean[i] - mean_square[i]; + } + + Array4D mean_4D = + *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); + auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, 
feature_index); + auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); + auto offset_4D = + *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); + + auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean_4D, var_4D, + scale_4D, offset_4D, epsilon); + + auto expected_normalized = Literal::CreateR4FromArray4D(normalized); + + auto offset_literal = Literal::CreateR1(offset); + auto scale_literal = Literal::CreateR1(scale); + auto input_literal = Literal::CreateR4FromArray4D(input_array); + + auto input_activations = + builder.Parameter(0, input_literal->shape(), "input"); + auto scale_activations = + builder.Parameter(1, scale_literal->shape(), "offset"); + auto offset_activations = + builder.Parameter(2, offset_literal->shape(), "scale"); + + auto expected = *Literal::MakeTuple({expected_normalized.get(), + Literal::CreateR1(mean).get(), + Literal::CreateR1(var).get()}); + + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + std::unique_ptr scale_data = + client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + std::unique_ptr offset_data = + client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + + builder.BatchNormTraining(input_activations, scale_activations, + offset_activations, epsilon, feature_index); + + ComputeAndCompareTuple( + &builder, expected, + {input_data.get(), scale_data.get(), offset_data.get()}, + ErrorSpec(0.01, 1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. 
+XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU( + DISABLED_ON_GPU(RandomizedGradTests)))) { + float epsilon = 0.001; + ComputationBuilder builder(client_, TestName()); + const std::vector& bounds = GetParam().bounds; + Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); + input_array.FillRandom(GetParam().random_value_var, + GetParam().random_value_mean); + + Array4D grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]); + grad_output_array.FillRandom(GetParam().random_value_var, + GetParam().random_value_mean); + + const int64 feature_index = GetParam().feature_index; + const int64 num_elements_per_feature = + Product(bounds) / bounds[feature_index]; + const int64 feature_bound = bounds[feature_index]; + std::vector scale(feature_bound, 2); + + auto input_squared = + ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); + std::vector reduce_dims; + for (int64 i = 0; i < bounds.size(); ++i) { + if (i != feature_index) { + reduce_dims.push_back(i); + } + } + + auto sum = + ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + auto sum_squared = + ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + std::vector mean(feature_bound); + + for (int64 i = 0; i < feature_bound; ++i) { + mean[i] = sum[i] / num_elements_per_feature; + } + + std::vector mean_square(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + mean_square[i] = mean[i] * mean[i]; + } + + std::vector square_mean(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + square_mean[i] = sum_squared[i] / num_elements_per_feature; + } + + std::vector var(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + var[i] = square_mean[i] - mean_square[i]; + } + + Array4D mean_4D = + *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); + auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, 
feature_index); + auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); + + auto var_add_epsilon = *ReferenceUtil::MapArray4D( + var_4D, [epsilon](float a) { return std::sqrt(a + epsilon); }); + + auto grad_output_times_var = + *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, + [](float a, float b) { return a * b; }); + + auto grad_activation = *ReferenceUtil::MapArray4D( + grad_output_times_var, scale_4D, [](float a, float b) { return a * b; }); + + auto activation_shifted = *ReferenceUtil::MapArray4D( + input_array, mean_4D, [](float a, float b) { return a - b; }); + + auto grad_scale_before_reduction = + *ReferenceUtil::MapArray4D(grad_output_times_var, activation_shifted, + [](float a, float b) { return a * b; }); + + auto grad_scale = ReferenceUtil::Reduce4DTo1D( + grad_scale_before_reduction, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + auto grad_offset = + ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + auto expected_grad_activation = + Literal::CreateR4FromArray4D(grad_activation); + + auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto scale_literal = Literal::CreateR1(scale); + auto mean_literal = Literal::CreateR1(mean); + auto var_literal = Literal::CreateR1(var); + auto grad_output_literal = + Literal::CreateR4FromArray4D(grad_output_array); + + auto input_parameter = builder.Parameter(0, input_literal->shape(), "input"); + auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale"); + auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean"); + auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance"); + auto grad_output_parameter = + builder.Parameter(4, grad_output_literal->shape(), "grad_output"); + + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + std::unique_ptr scale_data = + 
client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + std::unique_ptr mean_data = + client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + std::unique_ptr var_data = + client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + std::unique_ptr grad_output_data = + client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); + + auto t = builder.BatchNormGrad(input_parameter, scale_parameter, + mean_parameter, var_parameter, + grad_output_parameter, epsilon, feature_index); + + auto expected = + *Literal::MakeTuple({expected_grad_activation.get(), + Literal::CreateR1(grad_scale).get(), + Literal::CreateR1(grad_offset).get()}); + + ComputeAndCompareTuple(&builder, expected, + {input_data.get(), scale_data.get(), mean_data.get(), + var_data.get(), grad_output_data.get()}, + ErrorSpec(0.01, 1)); +} + +INSTANTIATE_TEST_CASE_P( + BatchNormTest_Instantiation, BatchNormTest, + ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f}, + BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f}, + + BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f}, + BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f}, + BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f}, + BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f}, + BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f}, + BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f}, + BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f}, + BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f}, + + BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000}, + BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000}, + + // Feature on low dimension to trigger relayout, test + // internal logical to physical dimension calculation + // is correct after relayout. + BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100})); + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. 
+XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) { + const int kFeatureIndex = 3; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) { + // Use 0 dimension as feature, tests layout analyzer. 
+ const int kFeatureIndex = 0; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + .get(), + Literal::CreateR1(std::vector(260, 1.0f)).get(), + Literal::CreateR1(std::vector(260, 0.0f)).get()}); + + ComputeAndCompareTuple(&builder, expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) { + // Test the correctness of choosing a large epsilon value. 
+ const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(1, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(1, 0.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + // var = 125, mean = 15, epsilon = -100 + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + .get(), + Literal::CreateR1(std::vector(1, 15.0f)).get(), + Literal::CreateR1(std::vector(1, 125.0f)).get()}); + + ComputeAndCompareTuple(&builder, expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on CPU and GPU. Disabled on 2017-07-11. 
+XLA_TEST_F(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU( + DISABLED_ON_GPU(BatchNormGradBasic)))) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = + builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); + + auto scale = builder.ConstantR1({1.0f, 1.0f}); + + auto mean = builder.ConstantR1({0.0f, 0.0f}); + + auto var = builder.ConstantR1({1.0f, 1.0f}); + + auto grad_output = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + + builder.BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}) + .get(), + Literal::CreateR1({0, 0}).get(), + Literal::CreateR1({16, 20}).get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 5e3b70702dd482e6b278386d70fef60b1bacb926..e6b853c2e4e4a08174012c1eb8be3739a2c9dba9 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -143,7 +142,6 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 25fe04a930e3783ff6024a0bb3bddc430c4fafdd..2a57835ca93cd2b367fe0402aee1f986ae2d4ff3 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -21,9 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -63,9 +61,8 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = - LiteralUtil::Relayout(*LiteralUtil::CreateR3FromArray3D(*r3_array), - LayoutUtil::MakeLayout(minor_to_major)); + auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); return r3_global_data; @@ -77,9 +74,8 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = - LiteralUtil::Relayout(*LiteralUtil::CreateR2FromArray2D(*r2_array), - LayoutUtil::MakeLayout(minor_to_major)); + auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); return r2_global_data; @@ -217,13 +213,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { ComputationBuilder b(client_, TestName()); b.Add(b.ConstantR2({{1.0, 5.0}}), - b.ConstantLiteral(*LiteralUtil::CreateR3( + b.ConstantLiteral(*Literal::CreateR3( {{{2.0}, {3.0}, {4.0}}, 
{{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); auto expected = - LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -292,7 +288,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + auto expected = Literal::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -317,7 +313,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { b.Add(r3h, r1h); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -325,81 +321,79 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + 
*Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { ComputationBuilder b(client_, TestName()); - auto r1 = - b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { ComputationBuilder b(client_, TestName()); - auto r1 = - b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r1 = + b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - 
LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -541,7 +535,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + auto expected = Literal::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -555,22 +549,22 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}})); - auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); + auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); - auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); + auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = 
b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); + auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); - auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); + auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -579,11 +573,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1, {0}); - auto expected = LiteralUtil::CreateR3( - {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = + Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -592,11 +586,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r1, r3, {1}); - auto expected = LiteralUtil::CreateR3( - {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = + Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -605,11 +599,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r1, r3, {2}); - auto expected = LiteralUtil::CreateR3( - {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = + Literal::CreateR3({{{11, 22}, {13, 24}}, 
{{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -620,7 +614,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = b.Add(r1_0, r3, {0}); r3 = b.Add(r3, r1_1, {1}); @@ -628,7 +622,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } r3 = b.Mul(r3, b.ConstantR0(-2)); - auto expected = LiteralUtil::CreateR3( + auto expected = Literal::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -649,7 +643,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { } r3 = b.Mul(r3, b.ConstantR0(-1)); - auto expected = LiteralUtil::CreateR3( + auto expected = Literal::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -662,7 +656,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { ComputationBuilder b(client_, TestName()); b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), - b.ConstantLiteral(*LiteralUtil::CreateR3( + b.ConstantLiteral(*Literal::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -704,8 +698,6 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 96a329a9bd8296a11a3e22e8dea31d71dd973d76..dc1443f5363aab1e6166984a3f2f3fccefad908e 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -39,7 +38,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {}), input, {})); @@ -48,14 +47,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(42.0), *result, + LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, error_spec_); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {})); @@ -65,14 +64,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + 
*Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, error_spec_); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple // to enable testing of the results. @@ -88,18 +87,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), result->tuple_literals(0), error_spec_); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), result->tuple_literals(1), error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); @@ -109,7 +108,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, error_spec_); } @@ -118,7 +117,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { // the dimensions, ie transpose. 
auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); @@ -128,14 +127,14 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); @@ -145,15 +144,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); + HloInstruction::CreateConstant(Literal::CreateR1({1.0, 2.0}))); // Broadcast vector in dimension 1. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -168,8 +167,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -178,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { int64 r1_size = input_data.size(); std::iota(input_data.begin(), input_data.end(), 0.0f); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); + HloInstruction::CreateConstant(Literal::CreateR1(input_data))); // Broadcast vector in dimension 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -198,8 +197,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -209,7 +208,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { std::vector r1_array(64, 42.0); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); + HloInstruction::CreateConstant(Literal::CreateR1(r1_array))); // Broadcast vector in dimension 1. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -220,14 +219,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, + error_spec_); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); @@ -240,15 +239,15 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { auto builder = HloComputation::Builder(TestName()); Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(to_broadcast))); + Literal::CreateR2FromArray2D(to_broadcast))); // Broadcast vector in dimensions 2 and 3. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -262,8 +261,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -282,7 +281,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { } } auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR3FromArray3D(input_vals))); + Literal::CreateR3FromArray3D(input_vals))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -293,8 +292,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } } // namespace @@ -302,7 +301,6 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 1f61743451a79a062205708d9ba6014f8a8591e9..dae0956f0a9f3d6fe172dc13b7ce3877a760e161 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -1,16 +1,32 @@ """Build rules for XLA testing.""" load("@local_config_cuda//cuda:build_defs.bzl", 
"cuda_is_configured") +load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") -def all_backends(): +all_backends = ["cpu", "cpu_parallel", "gpu"] + plugins.keys() + +def filter_backends(backends): + """Removes "gpu" from a backend list if CUDA is not enabled. + + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' + + Args: + backends: A list of backends to filter. + + Returns: + The filtered list of backends. + """ if cuda_is_configured(): - return ["cpu", "cpu_parallel", "gpu"] + return backends else: - return ["cpu", "cpu_parallel"] + return [backend for backend in backends if backend != "gpu"] + def xla_test(name, srcs, deps, + xla_test_library_deps=[], backends=[], args=[], tags=[], @@ -69,6 +85,8 @@ def xla_test(name, name: Name of the target. srcs: Sources for the target. deps: Dependencies of the target. + xla_test_library_deps: If set, the generated test targets will depend on the + respective cc_libraries generated by the xla_test_library rule. backends: A list of backends to generate tests for. Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will be generated for all supported backends. 
@@ -81,7 +99,7 @@ def xla_test(name, """ test_names = [] if not backends: - backends = all_backends() + backends = all_backends native.cc_library( name="%s_lib" % name, @@ -91,7 +109,7 @@ def xla_test(name, deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - for backend in backends: + for backend in filter_backends(backends): test_name = "%s_%s" % (name, backend) this_backend_tags = ["xla_%s" % backend] this_backend_copts = [] @@ -107,9 +125,18 @@ def xla_test(name, backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] this_backend_tags += ["requires-gpu-sm35"] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + this_backend_tags += plugins[backend]["tags"] + this_backend_args += plugins[backend]["args"] else: fail("Unknown backend %s" % backend) + if xla_test_library_deps: + for lib_dep in xla_test_library_deps: + backend_deps += ["%s_%s" % (lib_dep, backend)] + native.cc_test( name=test_name, srcs=srcs, @@ -124,19 +151,82 @@ def xla_test(name, native.test_suite(name=name, tests=test_names) +def xla_test_library(name, + srcs, + hdrs=[], + deps=[], + backends=[]): + """Generates cc_library targets for the given XLA backends. + + This rule forces the sources to be compiled for each backend so that the + backend specific macros could expand correctly. It's useful when test targets + in different directories referring to the same sources but test with different + arguments. + + Examples: + + # Generates the targets: foo_test_library_cpu and foo_test_gpu. 
+ xla_test_library( + name = "foo_test_library", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + # Then use the xla_test rule to generate test targets: + xla_test( + name = "foo_test", + srcs = [], + backends = ["cpu", "gpu"], + deps = [...], + xla_test_library_deps = [":foo_test_library"], + ) + + Args: + name: Name of the target. + srcs: Sources for the target. + hdrs: Headers for the target. + deps: Dependencies of the target. + backends: A list of backends to generate libraries for. + Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the + library will be generated for all supported backends. + """ + + if not backends: + backends = all_backends + + for backend in filter_backends(backends): + this_backend_copts = [] + if backend in ["cpu", "cpu_parallel", "gpu"]: + backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + else: + fail("Unknown backend %s" % backend) + + native.cc_library( + name = "%s_%s" % (name, backend), + srcs = srcs, + testonly = True, + hdrs = hdrs, + copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + deps = deps + backend_deps, + ) + def generate_backend_suites(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.test_suite(name="%s_tests" % backend, tags = ["xla_%s" % backend]) def generate_backend_test_macros(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.cc_library( name="test_macros_%s" % backend, testonly = True, diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 
55701c62db22f0fff6f4fdeabf0c72d600239969..086199fda1445c966917cff6849373e4474d16f7 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -78,7 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32IdentityComputation(); - auto constant = builder.ConstantLiteral(*LiteralUtil::CreateR0(42.0)); + auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); builder.Call(callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S0F32AdditionComputation(); - auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); - auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); + auto x = builder.ConstantLiteral(*Literal::CreateR1({})); + auto y = builder.ConstantLiteral(*Literal::CreateR1({})); builder.Call(callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -97,8 +96,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S2F32AdditionComputation(); - auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({1.0f, 2.0f})); - auto y = 
builder.ConstantLiteral(*LiteralUtil::CreateR1({2.0f, 3.0f})); + auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); + auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); builder.Call(callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -107,8 +106,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32TupleComputation(); - auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); + auto elem = Literal::CreateR0(42.0); + auto tuple = Literal::MakeTuple({elem.get()}); builder.Call(callee, {builder.ConstantLiteral(*elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); @@ -120,7 +119,6 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 4825eaf19dc28fd78a5d91a3c1e722c3916f6c20..2f4ad22f5bf0573ba97e6d28a3a207480fcdae18 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -38,7 +37,7 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { ComputationBuilder builder(client_, "add_two_params"); - auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); + auto param_literal = Literal::CreateR1({1.1f, 2.2f}); auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); @@ -55,18 +54,20 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { // The arity of the UserComputation is 2 arguments. Execution will succeed // with 2 arguments, but fail with a different number. 
- auto result_two_args = - client_->Execute(computation, {param0_data.get(), param1_data.get()}); + auto result_two_args = client_->Execute( + computation, {param0_data.get(), param1_data.get()}, &execution_options_); ASSERT_IS_OK(result_two_args.status()); - auto result_one_arg = client_->Execute(computation, {param0_data.get()}); + auto result_one_arg = + client_->Execute(computation, {param0_data.get()}, &execution_options_); ASSERT_FALSE(result_one_arg.ok()); ASSERT_EQ(result_one_arg.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(result_one_arg.status().error_message(), ContainsRegex("takes 2")); - auto result_zero_args = client_->Execute(computation, {}); + auto result_zero_args = + client_->Execute(computation, {}, &execution_options_); ASSERT_FALSE(result_zero_args.ok()); ASSERT_EQ(result_zero_args.status().code(), tensorflow::error::INVALID_ARGUMENT); @@ -85,35 +86,38 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_IS_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto f32_literal = LiteralUtil::CreateR0(1.1f); + auto f32_literal = Literal::CreateR0(1.1f); auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); - auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); - auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); + auto u8_4_literal = Literal::CreateR1U8("hola"); auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); // Match - auto status = - client_->Execute(computation, {f32_data.get(), f32_4_data.get()}); + auto status = client_->Execute( + computation, {f32_data.get(), f32_4_data.get()}, &execution_options_); ASSERT_IS_OK(status.status()); // Shape mismatch in parameter 0 - status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}); + 
status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), ContainsRegex("expects parameter 0")); // Shape mismatch in parameter 1 (rank) - status = client_->Execute(computation, {f32_data.get(), f32_data.get()}); + status = client_->Execute(computation, {f32_data.get(), f32_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), ContainsRegex("expects parameter 1")); // Shape mismatch in parameter 1 (element type) - status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}); + status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), @@ -126,7 +130,6 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index b96bb8f846909589a52269f0d314dbfd0af2be09..3082630505fe9aea9222ed478a1e6504e18231b6 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,18 +37,20 @@ namespace xla { namespace { // Wrapper function that creates a nicer error message (than a bare // ValueOrDie()) if the platform we intend to test is not available. -Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { - StatusOr result = ClientLibrary::GetOrCreateLocalClient(platform); +Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(client_options); TF_CHECK_OK(result.status()) << "could not create local client for testing"; return result.ValueOrDie(); } } // namespace -ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) - : client_(GetOrCreateLocalClientOrDie(platform)) { - *(execution_options_.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); - +ClientLibraryTestBase::ClientLibraryTestBase( + perftools::gputools::Platform* platform, + const LocalClientOptions& client_options) + : client_(GetOrCreateLocalClientOrDie(client_options)), + execution_options_(CreateDefaultExecutionOptions()) { + CHECK_EQ(platform, client_options.platform()); // Disabling constant_folding so that tests (usually written using Constants) // will exercise the intended code paths, instead of being constant folded. 
// @@ -59,6 +61,15 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) "constant_folding"); } +ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) + : execution_options_(CreateDefaultExecutionOptions()) { + LocalClientOptions default_options; + default_options.set_platform(platform); + client_ = GetOrCreateLocalClientOrDie(default_options); + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "constant_folding"); +} + string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } @@ -71,13 +82,16 @@ StatusOr> ClientLibraryTestBase::Execute( return client_->Execute(computation, arguments, &execution_options_); } +StatusOr ClientLibraryTestBase::ExecuteAsync( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments) { + return client_->ExecuteAsync(computation, arguments, &execution_options_); +} + StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, + const Computation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { - // Build the computation, as a convenience. - TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = @@ -87,6 +101,15 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. 
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { @@ -113,14 +136,14 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return LiteralUtil::ToString(*result.ValueOrDie()); + return result.ValueOrDie()->ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); + std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -141,18 +164,121 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( error, shape_with_layout)); } +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::Computation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output) { + // Try with no layout requirement. + TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); + verify_output(*actual, ""); + + // Try with all output layouts. 
+ std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + do { + auto layout = ShapeUtil::MakeShapeWithLayout( + expected.shape().element_type(), + AsInt64Slice(expected.shape().dimensions()), minor_to_major); + TF_ASSIGN_OR_RETURN(auto actual, + ExecuteAndTransfer(computation, arguments, &layout)); + verify_output(*actual, tensorflow::strings::StrCat( + "Test with output layout: ", + ShapeUtil::HumanStringWithLayout(layout))); + } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); + return tensorflow::Status::OK(); +} + +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::Computation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output, + const Shape* output_with_layout) { + std::vector arguments_with_layout; + std::vector layout_strings; + // This is a recursive function. It's an std::function instead of a lambda + // because it needs to capture itself. The index is the index of the argument + // to try all layouts for. + std::function choose; + choose = [&, this](int64 index) -> tensorflow::Status { + if (index < arguments.size()) { + // Try out all layouts for the operand. + TF_ASSIGN_OR_RETURN(auto literal, + client_->Transfer(*arguments[index], nullptr)); + // Skip tuples because they don't have a rank. 
+ if (ShapeUtil::IsTuple(literal->shape())) { + layout_strings.push_back( + ShapeUtil::HumanStringWithLayout(literal->shape())); + arguments_with_layout.push_back(arguments[index]); + TF_RETURN_IF_ERROR(choose(index + 1)); + arguments_with_layout.pop_back(); + layout_strings.pop_back(); + return tensorflow::Status::OK(); + } + + std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + do { + auto literal_relayout = + literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + layout_strings.push_back( + ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + TF_ASSIGN_OR_RETURN(auto data, + client_->TransferToServer(*literal_relayout)); + arguments_with_layout.push_back(data.get()); + TF_RETURN_IF_ERROR(choose(index + 1)); + arguments_with_layout.pop_back(); + layout_strings.pop_back(); + } while ( + std::next_permutation(minor_to_major.begin(), minor_to_major.end())); + return tensorflow::Status::OK(); + } + + // Every argument has an assigned layout. 
+ TF_ASSIGN_OR_RETURN( + auto actual, + ExecuteAndTransfer( + computation, + tensorflow::gtl::ArraySlice(arguments_with_layout), + output_with_layout)); + string error_message = "Test with input layouts: "; + for (const auto& str : layout_strings) { + tensorflow::strings::StrAppend(&error_message, str, " "); + } + verify_output(*actual, error_message); + return tensorflow::Status::OK(); + }; + + return choose(0); +} + tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { - TF_ASSIGN_OR_RETURN( - auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout)); + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); if (ShapeUtil::ElementIsFloating(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } else { TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || expected.shape().element_type() == PRED); } + auto expect_equal = [&](const Literal& actual, const string& error_message) { + LiteralTestUtil::ExpectEqual(expected, actual, error_message); + }; + if (execution_options_.debug_options().xla_test_all_output_layouts()) { + return ComputeAndCompareLiteralWithAllOutputLayouts( + computation, expected, arguments, expect_equal); + } + if (execution_options_.debug_options().xla_test_all_input_layouts()) { + return ComputeAndCompareLiteralWithAllInputLayouts( + computation, expected, arguments, expect_equal, shape_with_layout); + } + TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, + shape_with_layout)); LiteralTestUtil::ExpectEqual(expected, *actual); return tensorflow::Status::OK(); } @@ -161,9 +287,21 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { - 
TF_ASSIGN_OR_RETURN( - auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout)); TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape())); + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + auto expect_near = [&](const Literal& actual, const string& error_message) { + LiteralTestUtil::ExpectNear(expected, actual, error, error_message); + }; + if (execution_options_.debug_options().xla_test_all_output_layouts()) { + return ComputeAndCompareLiteralWithAllOutputLayouts(computation, expected, + arguments, expect_near); + } + if (execution_options_.debug_options().xla_test_all_input_layouts()) { + return ComputeAndCompareLiteralWithAllInputLayouts( + computation, expected, arguments, expect_near, shape_with_layout); + } + TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, + shape_with_layout)); LiteralTestUtil::ExpectNear(expected, *actual, error); return tensorflow::Status::OK(); } @@ -179,10 +317,10 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + std::unique_ptr expected_literal = Literal::CreateR1U8(expected); - VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); - VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); + VLOG(1) << "expected: " << expected_literal->ToString(); + VLOG(1) << "actual: " << actual->ToString(); EXPECT_EQ(expected, actual->u8s_string()); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index f9e1082ebb43ae112c417ff9a71ef8d38b5de900..19c179c4ba250e055912899db42a3e64cbfa9001 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -48,6 +49,10 @@ class ClientLibraryTestBase : public ::testing::Test { explicit ClientLibraryTestBase( perftools::gputools::Platform* platform = nullptr); + // Creates a new ClientLibraryTestBase with custom client options. + ClientLibraryTestBase(perftools::gputools::Platform* platform, + const LocalClientOptions& client_options); + // Returns the name of the test currently being run. string TestName() const; @@ -66,14 +71,23 @@ class ClientLibraryTestBase : public ::testing::Test { // TODO(b/25566808): Add helper that populates a literal from a testdata file. - // Convenience methods for building and running a computation from a builder. + // Convenience methods for building and running a computation with the member + // execution options. Modify execution_options_ in your test if you want to + // customize the options. StatusOr> Execute( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); + StatusOr ExecuteAsync( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments); StatusOr> ExecuteAndTransfer( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); // Convenience OrDie variants of above methods. 
std::unique_ptr ExecuteOrDie( @@ -271,6 +285,22 @@ class ClientLibraryTestBase : public ::testing::Test { Client* client_; ExecutionOptions execution_options_; + + private: + // Build and run the computation with all permutations of output layouts. + tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::Computation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output); + // Build and run the computation with all permutations of layouts of all input + // arguments. + tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + const xla::Computation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output, + const Shape* output_with_layout = nullptr); }; template @@ -278,7 +308,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); + Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -291,7 +321,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); + Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -301,7 +331,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -314,7 +344,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value, "Floating point 
type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -324,7 +354,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( ComputationBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR2FromArray2D(expected); + Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -337,7 +367,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR2FromArray2D(expected); + Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -347,7 +377,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( ComputationBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR3FromArray3D(expected); + Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -360,7 +390,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR3FromArray3D(expected); + Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -370,7 +400,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( ComputationBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR4FromArray4D(expected); + 
Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -383,7 +413,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR4FromArray4D(expected); + Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -392,7 +422,7 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); + std::unique_ptr literal = Literal::CreateR0(value); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -404,7 +434,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); + std::unique_ptr literal = Literal::CreateR1(values); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -416,7 +446,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); + std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ 
-428,7 +458,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); + std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 1247804dae0effd387d5f276a3d64667bc69e18b..e84a6ce710229043c903c5e50daf33e2f93fa6da 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -47,7 +46,7 @@ TEST_F(ClientTest, ExecuteWithLayout) { auto computation = b.Build(); ASSERT_TRUE(computation.ok()) << computation.status(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, execute_layout); @@ -77,7 +76,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { auto computation = b.Build(); ASSERT_TRUE(computation.ok()) << computation.status(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; // Create a result shape with one element column major and the other row // major. 
*execution_options.mutable_shape_with_output_layout() = @@ -115,7 +114,6 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index cc3eb0e8d46a8ab13553cb78f58bfc48b16ee862..90767c4a17478d4e7edd6202a8629db5b115381d 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -32,6 +33,20 @@ limitations under the License. namespace xla { +std::unique_ptr CodegenTestBase::CreateNewModuleWithEmbeddedIr( + bool ftz) { + HloModuleConfig config; + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_embed_ir_in_executable(true); + debug_options.set_xla_gpu_ftz(ftz); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
+ debug_options.add_xla_disable_hlo_passes("constant_folding"); + config.set_debug_options(debug_options); + + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} + void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = @@ -43,8 +58,7 @@ void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, std::unique_ptr CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { return backend_->compiler() - ->Compile(std::move(hlo_module), test_hlo_dumper_, - backend_->default_stream_executor()) + ->Compile(std::move(hlo_module), backend_->default_stream_executor()) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h index 50c0453107095c5fdb6238c88a17b31728b6bf22..fa073cd91ee07462d7aaf40789e87dbc831da95e 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.h +++ b/tensorflow/compiler/xla/tests/codegen_test_base.h @@ -28,7 +28,11 @@ namespace xla { // Tests that verify IR emitted by the CPU/GPU backend is as expected. class CodegenTestBase : public HloTestBase { protected: - CodegenTestBase() {} + // Like HloTestBase::CreateNewModule, but also sets the "embed ir in + // executable" flag to true, since this is needed for codegen tests. + // The optional ftz flags configures whether these modules have their ftz + // option turned on. + std::unique_ptr CreateNewModuleWithEmbeddedIr(bool ftz = false); // Returns the embedded LLVM IR from the given executable. 
Codegen tests must // override this method, but execution tests do not have to because they do diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 18ea9714d1a8f5f5b127881f657e948d65003ab1..7038afc5b1f5dd388731ae82586fe24ac5476e8b 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,10 +47,10 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::unique_ptr result = client_ ->ExecuteAndTransfer(computation, arguments, - /*execution_options=*/nullptr, + /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(expected_result), + LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), *result, error_spec_); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -62,14 +61,13 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - auto data_handle = - client_ - ->Execute(computation, arguments, /*execution_options=*/nullptr, - &execution_profile) - .ConsumeValueOrDie(); + auto data_handle = client_ + ->Execute(computation, arguments, + &execution_options_, &execution_profile) + .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - 
LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2(expected_result), + LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), *result, error_spec_); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(*Literal::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(*Literal::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(*Literal::CreateR0(456.0f)) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -205,7 +203,6 @@ XLA_TEST_F(CompilationCacheTest, MutatedComputation) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 13c78fb16331340ae9b3586ac47a071230b73a83..4384c9b31495437db10744ea2b98b5b0b05b7ae4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -86,7 +85,7 @@ class ComputeConstantTest : public ::testing::Test { ComputationBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder)); - return LiteralUtil::Get(*literal, {}); + return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, @@ -211,7 +210,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); + Literal::CreateR1({4, 6}); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } @@ -225,7 +224,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); + std::unique_ptr expected_literal = Literal::CreateR0(5); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } @@ -291,7 +290,6 @@ TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 
a7034930bc9493dfc4931a77c05cf87e4d138173..c5d88ad6a08476731b5b09cb4ae16a3e76bbaf98 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -518,8 +517,8 @@ TEST_P(ConcatR2BinaryTest, DoIt) { // concat XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = LiteralUtil::CreateR0(2.f); - auto y_literal = LiteralUtil::CreateR0(3.f); + auto x_literal = Literal::CreateR0(2.f); + auto y_literal = Literal::CreateR0(3.f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -540,9 +539,9 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { // produces the correct result in rank 1. 
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); - auto y_literal = LiteralUtil::CreateR0(1.5f); - auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_literal = Literal::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = Literal::CreateR0(1.5f); + auto z_literal = Literal::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -568,9 +567,9 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); Array3D x3d(3, 5, 7, 3.14f); - auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); - auto y_literal = LiteralUtil::CreateR0(1.5f); - auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_literal = Literal::CreateR3FromArray3D(x3d); + auto y_literal = Literal::CreateR0(1.5f); + auto z_literal = Literal::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -607,7 +606,6 @@ INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 1c065de8ba7663ac2e7b3dcd52298e6587d993f0..7c276c8c8d0c0e97b0dfba7a5d6a6165386e5261 100644 --- 
a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -113,7 +112,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { ComputationBuilder builder(client_, TestName()); auto constant = builder.ConstantLiteral( - *LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2))); + *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -128,8 +127,8 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = builder.ConstantLiteral( - *LiteralUtil::CreateR3FromArray3D(array3d)); + auto constant = + builder.ConstantLiteral(*Literal::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -143,7 +142,7 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - Literal input_literal = *LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = *Literal::CreateR4FromArray4D(input_array); { ComputationBuilder builder(client_, TestName()); @@ -161,9 +160,9 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. 
TEST_F(ConstantsTest, DISABLED_TupleConstant) { ComputationBuilder builder(client_, TestName()); - builder.ConstantLiteral(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + builder.ConstantLiteral( + *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); @@ -179,7 +178,6 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 6d3797972507f2c17b545c612c0dd839212e5ae5..2d181938ded0804776847772d4bb58bbc5e334f4 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,6 +69,24 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } +TEST_F(ConvertTest, ConvertR1PREDToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({true, false, true}); + builder.ConvertElementType(a, S32); + + std::vector expected = {1, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1PREDToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({true, false, true}); + builder.ConvertElementType(a, F32); + + std::vector expected = {1., 0., 1.}; + ComputeAndCompareR1(&builder, expected, {}); +} + XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); @@ -197,7 +214,6 @@ TEST_F(ConvertTest, ConvertReshape) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 0b09416a74771a8a9df804dcae783dc220420fc2..fb50d9b0ebf5b4a6c9d244f699620e2dcb74acaf 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -21,7 +21,6 @@ limitations under the 
License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -63,8 +62,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = MakeUnique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -102,7 +100,6 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index ec19469fa66c16cff3d1349b7ccc1d0de94d0b54..a110082f9a52ded5e836fa835e82f790e05df0e0 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -115,10 +114,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -158,10 +157,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -201,10 +200,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) 
.ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -246,10 +245,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -273,10 +272,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -313,21 +312,18 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); std::iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = - LiteralUtil::Reshape(*input_r1, input_dims).ConsumeValueOrDie(); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = - LiteralUtil::Reshape(*filter_r1, filter_dims).ConsumeValueOrDie(); + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - 
auto expected_r1 = LiteralUtil::CreateR1( + auto expected_r1 = Literal::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = - LiteralUtil::Reshape(*expected_r1, {1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); auto filter_literal = @@ -344,7 +340,6 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index b5afc2498dace11c57a7099e9a3d32eb2a387984..c8e74aa01a50042b1e5297920cc184b1eeb51fd3 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -1312,20 +1311,19 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { ComputationBuilder builder(client_, TestName()); - auto gradients_flat = LiteralUtil::CreateR1({1}); + auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = - LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 1}) - .ConsumeValueOrDie(); + gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto gradients = builder.ConstantLiteral(*gradients_literal); - auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); + auto weights_flat = Literal::CreateR1({1, 10, 100}); auto weights_literal = - LiteralUtil::Reshape(*weights_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto weights = builder.ConstantLiteral(*weights_literal); - auto expected_flat = LiteralUtil::CreateR1({10}); + auto expected_flat = Literal::CreateR1({10}); auto expected_literal = - LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); builder.ConvWithGeneralPadding(gradients, mirrored_weights, @@ -1337,21 +1335,19 @@ TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { ComputationBuilder builder(client_, TestName()); - auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); + auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto 
activations_literal = - LiteralUtil::Reshape(*activations_flat, {1, 1, 1, 1, 4}) - .ConsumeValueOrDie(); + activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); auto activations = builder.ConstantLiteral(*activations_literal); - auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); + auto gradients_flat = Literal::CreateR1({100, 10, 1}); auto gradients_literal = - LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 3}) - .ConsumeValueOrDie(); + gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto gradients = builder.ConstantLiteral(*gradients_literal); - auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); + auto expected_flat = Literal::CreateR1({13, 24, 130}); auto expected_literal = - LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = builder.ConvGeneralDilated( activations, gradients, @@ -1370,7 +1366,6 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 4c2413d0fe43d486ebf306fc51601467d6ebf7fd..76ae280f1a0f309d9aa159079827a7e2c7e833d7 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -58,39 +57,34 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); -} +TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } -TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); -} +TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(*Literal::CreateR1({1, 2, 3})); } TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(*Literal::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(*Literal::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } TEST_F(CopyOpTest, CopyParameterScalar) { auto builder = HloComputation::Builder(TestName()); // Copy literal to device to use as parameter. 
- auto literal = LiteralUtil::CreateR0(42.0); + auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); auto constant_device_base = TransferToDevice(*literal); @@ -112,7 +106,7 @@ TEST_F(CopyOpTest, CopyParameterScalar) { TEST_F(CopyOpTest, CopyConstantR2Twice) { auto builder = HloComputation::Builder(TestName()); - auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -134,7 +128,7 @@ TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal->mutable_shape()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); @@ -170,7 +164,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + std::unique_ptr literal = Literal::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -204,7 +198,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + std::unique_ptr literal = Literal::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -247,7 +241,7 @@ using CopyOpClientTest = ClientLibraryTestBase; XLA_TEST_F(CopyOpClientTest, Copy0x0) { Shape in_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {0, 1}); Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); - auto empty = LiteralUtil::CreateFromShape(in_shape); + auto empty = 
Literal::CreateFromShape(in_shape); ComputationBuilder builder(client_, TestName()); auto param0 = builder.Parameter(0, in_shape, "input"); @@ -263,7 +257,6 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 32232acf6e34517587b80d5091dbb9d603223184..73772fdec02fc95cb6c8e0685037515183478e85 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -68,7 +67,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); @@ -89,7 +88,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { array(1, 1) = 4.0f; auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); + HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); @@ -105,7 +104,7 @@ XLA_TEST_F(CustomCallTest, auto b = 
HloComputation::Builder(TestName()); auto input = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + HloInstruction::CreateConstant(Literal::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); @@ -129,7 +128,6 @@ XLA_TEST_F(CustomCallTest, int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 074753bf6f8f9e64626b9ed2015b94b58dfebc87..0c7c3a8ff6656b05041e672cca97b285a4420446 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -42,7 +41,8 @@ class DeallocationTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice arguments) { Computation computation = builder->Build().ConsumeValueOrDie(); auto global_data = - client_->Execute(computation, arguments).ConsumeValueOrDie(); + client_->Execute(computation, arguments, &execution_options_) + .ConsumeValueOrDie(); TF_CHECK_OK(client_->Transfer(*global_data).status()); return global_data; } @@ -143,7 +143,6 @@ XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index fcddffc1e1340028f11b67cbe14537a240120de7..c65f8c0f08bb8a096b020e73a35cdbb70e517b1f 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,7 +47,8 @@ class DeconstructTupleTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice arguments) { Computation computation = builder->Build().ConsumeValueOrDie(); auto global_data = - client_->Execute(computation, arguments).ConsumeValueOrDie(); + client_->Execute(computation, arguments, &execution_options_) + .ConsumeValueOrDie(); TF_CHECK_OK(client_->Transfer(*global_data).status()); return global_data; } @@ -67,9 +67,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); std::unique_ptr literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); } @@ -89,17 +89,17 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles2 = result_status2.ConsumeValueOrDie(); std::unique_ptr literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles1[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles1[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); 
handles1[0].reset(); handles1[1].reset(); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles2[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles2[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); } @@ -119,13 +119,13 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { auto handles = result_status.ConsumeValueOrDie(); std::unique_ptr literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[3])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); } @@ -145,17 +145,17 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { global_data.reset(); std::unique_ptr literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, 
client_->Transfer(*handles[2])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); } @@ -173,7 +173,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -205,7 +205,6 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..60953a7421d410722b499625b4ce4b9ca90aa874 --- /dev/null +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" + +namespace xla { +namespace { +TEST_F(ClientLibraryTestBase, DeepGraph) { + // TODO(b/62624812): To trigger the stack overflow this test is + // intended to track, we need to set kDepth to 20000. + // Unfortunately, setting it that high causes the test to time out. + const int kDepth = 200; + ComputationBuilder b(client_, TestName()); + ComputationDataHandle x; + ComputationDataHandle y; + auto x_data = CreateR0Parameter(3, 0, "x", &b, &x); + auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); + ComputationDataHandle z = x; + for (int i = 0; i < kDepth; ++i) { + z = b.Add(z, y); + } + ComputeAndCompareR0(&b, /*expected=*/kDepth + 3, + {x_data.get(), y_data.get()}); +} +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 
754eec1b1edc286b98d02f70c8e5661523bd85de..59ee0073388fe824ee9bc92819c9d10eca624473 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -20,10 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -186,14 +183,14 @@ void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, bool rhs_row_major) { std::unique_ptr> lhs_data = MakeLinspaceArray2D(0.0, 1.0, M, K); - std::unique_ptr lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + std::unique_ptr lhs_lit = Literal::CreateR2FromArray2DWithLayout( *lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))); auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie(); std::unique_ptr> rhs_data = MakeLinspaceArray2D(0.0, 1.0, K, N); - std::unique_ptr rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + std::unique_ptr rhs_lit = Literal::CreateR2FromArray2DWithLayout( *rhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))); auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie(); @@ -380,12 +377,12 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = client_ - ->TransferToServer(*LiteralUtil::CreateR4( + ->TransferToServer(*Literal::CreateR4( {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, {{{3000, 300}, {30, 3}}, {{4000, 
400}, {40, 4}}}})) .ConsumeValueOrDie(); auto y_data = client_ - ->TransferToServer(*LiteralUtil::CreateR4( + ->TransferToServer(*Literal::CreateR4( {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) .ConsumeValueOrDie(); @@ -416,14 +413,14 @@ TEST_F(DotOperationTest, TransposeFolding) { auto lhs_handle = client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + *Literal::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + *Literal::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -460,10 +457,7 @@ TEST_F(DotOperationTest, TransposeFolding) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendLayoutUtilFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index b7bb1792f3b9b96fea5f446c787eb55e2577b01b..9e85e357070c8c7a32bdc8b16b139ceb848114d9 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -45,295 +44,310 @@ namespace { class DynamicSliceTest : public ClientLibraryTestBase { protected: - template + template void TestR1() { // Slice at dimension start. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {0}, {5}, - {0.0, 1.0, 2.0, 3.0, 4.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {0}, {5}, {0, 1, 2, 3, 4}); // Slice in the middle. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {3}, - {2.0, 3.0, 4.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4}); // Slice at dimension boundaries. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {5}, {3}, - {5.0, 6.0, 7.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7}); // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4}, - {6.0, 7.0, 0.0, 1.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); // Zero element slice. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {0}, {}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {}); } - template + template void TestR2() { // Slice at dimension start. - RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {0, 0}, {2, 2}, {{1.0f, 2.0f}, {4.0f, 5.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 2}, + {{1, 2}, {4, 5}}); // Slice in the middle. - RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {1, 1}, {2, 1}, {{5.0f}, {8.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1}, + {{5}, {8}}); // Slice at dimension boundaries. 
- RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {1, 1}, {2, 1}, {{5.0f}, {8.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1}, + {{5}, {8}}); // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {1, 1}, {3, 3}, - {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, + {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); // Zero element slice: 2x0. - RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {0, 0}, {2, 0}, {{}, {}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0}, + {{}, {}}); // Zero element slice: 0x2. - RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {0, 0}, {0, 2}, Array2D(0, 2)); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2}, + Array2D(0, 2)); } - template + template void TestR3() { // R3 Shape: [2, 3, 2] // clang-format off // Slice at dimension start. - RunR3( - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, - {0, 0, 0}, {2, 1, 2}, - {{{1.0f, 2.0f}}, {{7.0f, 8.0f}}}); + RunR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{7, 8}, {9, 10}, {11, 12}}}, + {0, 0, 0}, {2, 1, 2}, + {{{1, 2}}, {{7, 8}}}); // Slice in the middle. - RunR3( - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, - {0, 1, 1}, {2, 2, 1}, - {{{4.0f}, {6.0f}}, {{10.0f}, {12.0f}}}); + RunR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{7, 8}, {9, 10}, {11, 12}}}, + {0, 1, 1}, {2, 2, 1}, + {{{4}, {6}}, {{10}, {12}}}); // Slice at dimension boundaries, but with sizes that cause indices to wrap. 
- RunR3( - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, - {0, 2, 1}, {2, 2, 1}, - {{{6.0f}, {2.0f}}, {{12.0f}, {8.0f}}}); + RunR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{7, 8}, {9, 10}, {11, 12}}}, + {0, 2, 1}, {2, 1, 2}, + {{{6, 5}}, {{12, 11}}}); // clang-format on } - template - void RunR1(const std::vector& input_values, + template + void RunR1(tensorflow::gtl::ArraySlice input_values, const std::vector slice_starts, const std::vector& slice_sizes, - const std::vector& expected_values) { + tensorflow::gtl::ArraySlice expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR1(input_values); + auto input = builder.ConstantR1(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareR1(&builder, expected_values, {start_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR1(&builder, expected_values, {start_data.get()}); } - template - void RunR2(const Array2D& input_values, + template + void RunR2(const Array2D& input_values, const std::vector slice_starts, const std::vector& slice_sizes, - const Array2D& expected_values) { + const Array2D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR2FromArray2D(input_values); + auto input = builder.ConstantR2FromArray2D(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. 
- ComputeAndCompareR2(&builder, expected_values, {start_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR2(&builder, expected_values, {start_data.get()}); } - template - void RunR3(const Array3D& input_values, + template + void RunR3(const Array3D& input_values, const std::vector slice_starts, const std::vector& slice_sizes, - const Array3D& expected_values) { + const Array3D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR3FromArray3D(input_values); + auto input = builder.ConstantR3FromArray3D(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, {start_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR3(&builder, expected_values, {start_data.get()}); } }; -XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } + +XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } + +XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { + // Slice 
at dimension start. + RunR1({true, false, false, true, false, true, true, false}, {0}, + {5}, {true, false, false, true, false}); + // Slice in the middle. + RunR1({true, false, false, true, false, true, true, false}, {2}, + {3}, {false, true, false}); + // Slice at dimension boundaries. + RunR1({true, false, false, true, false, true, true, false}, {5}, + {3}, {true, true, false}); + // Zero element slice. + RunR1({true, false, false, true, false, true, true, false}, {2}, + {0}, {}); +} -XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R2Pred) { + // Slice at dimension start. + RunR2( + {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0}, + {2, 2}, {{true, false}, {false, false}}); + // Slice in the middle. + RunR2( + {{true, false, true}, {false, false, true}, {true, true, false}}, {1, 1}, + {2, 1}, {{false}, {true}}); + // Slice at dimension boundaries. + RunR2( + {{true, false, true}, {false, false, true}, {true, true, false}}, {1, 1}, + {2, 1}, {{false}, {true}}); + // Zero element slice: 2x0. + RunR2( + {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0}, + {2, 0}, {{}, {}}); + // Zero element slice: 0x2. + RunR2( + {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0}, + {0, 2}, Array2D(0, 2)); +} -XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { + // R3 Shape: [2, 3, 2] + // clang-format off + + // Slice at dimension start. + RunR3( + {{{true, false}, {false, true}, {true, true}}, + {{false, true}, {true, false}, {false, false}}}, + {0, 0, 0}, {2, 1, 2}, + {{{true, false}}, {{false, true}}}); + + // Slice in the middle. 
+ RunR3( + {{{true, false}, {false, true}, {true, true}}, + {{false, true}, {true, false}, {false, false}}}, + {0, 1, 1}, {2, 2, 1}, + {{{true}, {true}}, {{false}, {false}}}); + + // clang-format on +} class DynamicUpdateSliceTest : public ClientLibraryTestBase { protected: - template + template void TestR1() { - // clang-format off // Slice at dimension start. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, - {8.0, 9.0, 10.0}, {0}, - {8.0, 9.0, 10.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0}, + {8, 9, 10, 3, 4, 5, 6, 7}); // Slice in the middle. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, - {8.0, 9.0, 10.0}, {2}, - {0.0, 1.0, 8.0, 9.0, 10.0, 5.0, 6.0, 7.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {2}, + {0, 1, 8, 9, 10, 5, 6, 7}); // Slice at dimension boundaries. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, - {8.0, 9.0, 10.0}, {5}, - {0.0, 1.0, 2.0, 3.0, 4.0, 8.0, 9.0, 10.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5}, + {0, 1, 2, 3, 4, 8, 9, 10}); // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, - {8.0, 9.0, 10.0}, {6}, - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0}); + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, + {0, 1, 2, 3, 4, 5, 8, 9}); // Zero-sized update. - RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, - {}, {2}, - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); - // clang-format on + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2}, + {0, 1, 2, 3, 4, 5, 6, 7}); } - template + template void TestR2() { - // clang-format off // Slice at dimension start. - RunR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {{10.0f, 11.0f}}, {0, 0}, - {{10.0f, 11.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {0, 0}, + {{10, 11, 3}, {4, 5, 6}, {7, 8, 9}}); // Slice in the middle. 
- RunR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {{10.0f, 11.0f}}, {1, 1}, - {{1.0f, 2.0f, 3.0f}, {4.0f, 10.0f, 11.0f}, {7.0f, 8.0f, 9.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {1, 1}, + {{1, 2, 3}, {4, 10, 11}, {7, 8, 9}}); // Slice at dimension boundaries. - RunR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {{10.0f, 11.0f}}, {2, 1}, - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 10.0f, 11.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1}, + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {{10.0f, 11.0f}}, {2, 2}, - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}}); + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 10}}); // Zero-sized update. - RunR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, - {{}}, {2, 1}, - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); - // clang-format on + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } - template + template void TestR3() { // R3 Shape: [2, 3, 2] - // clang-format off // Slice at dimension start. - RunR3( - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, - {{{13.0f, 14.0f}, {15.0f, 16.0f}}, - {{17.0f, 18.0f}, {19.0f, 20.0f}}}, - {0, 0, 0}, - {{{13.0f, 14.0f}, {15.0f, 16.0f}, {5.0f, 6.0f}}, - {{17.0f, 18.0f}, {19.0f, 20.0f}, {11.0f, 12.0f}}}); + RunR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, + {{{13, 14}, {15, 16}}, {{17, 18}, {19, 20}}}, {0, 0, 0}, + {{{13, 14}, {15, 16}, {5, 6}}, {{17, 18}, {19, 20}, {11, 12}}}); // Slice in the middle. 
- RunR3( - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, - {{{13.0f}, {15.0f}}}, - {1, 1, 1}, - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 13.0f}, {11.0f, 15.0f}}}); + RunR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, + {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR3( - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, - {{{13.0f}, {15.0f}}}, - {1, 2, 1}, - {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, - {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 13.0f}}}); - // clang-format on + RunR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 13}}}); } - template - void RunR1(const std::vector& input_values, - const std::vector& update_values, + template + void RunR1(tensorflow::gtl::ArraySlice input_values, + tensorflow::gtl::ArraySlice update_values, const std::vector slice_starts, - const std::vector& expected_values) { + tensorflow::gtl::ArraySlice expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR1(input_values); - auto update = builder.ConstantR1(update_values); + auto input = builder.ConstantR1(input_values); + auto update = builder.ConstantR1(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. 
- ComputeAndCompareR1(&builder, expected_values, {start_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR1(&builder, expected_values, {start_data.get()}); } - template - void RunR2(const Array2D& input_values, - const Array2D& update_values, + template + void RunR2(const Array2D& input_values, + const Array2D& update_values, const std::vector slice_starts, - const Array2D& expected_values) { + const Array2D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR2FromArray2D(input_values); - auto update = builder.ConstantR2FromArray2D(update_values); + auto input = builder.ConstantR2FromArray2D(input_values); + auto update = builder.ConstantR2FromArray2D(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR2(&builder, expected_values, {start_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR2(&builder, expected_values, {start_data.get()}); } - template - void RunR3(const Array3D& input_values, - const Array3D& update_values, + template + void RunR3(const Array3D& input_values, + const Array3D& update_values, const std::vector slice_starts, - const Array3D& expected_values) { + const Array3D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. 
- auto input = builder.ConstantR3FromArray3D(input_values); - auto update = builder.ConstantR3FromArray3D(update_values); + auto input = builder.ConstantR3FromArray3D(input_values); + auto update = builder.ConstantR3FromArray3D(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, {start_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR3(&builder, expected_values, {start_data.get()}); } void RunR3Contiguous(std::vector operand_shape, int32 index, @@ -389,28 +403,86 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal); + Literal::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal->ToString(); } }; -XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } 
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { + // Slice at dimension start. + RunR1({false, false, true, true, false, true, true, false}, + {true, true, false}, {0}, + {true, true, false, true, false, true, true, false}); + // Slice in the middle. + RunR1({false, false, true, true, false, true, true, false}, + {false, true, true}, {2}, + {false, false, false, true, true, true, true, false}); + // Slice at dimension boundaries. + RunR1({false, false, true, true, false, true, true, false}, + {false, true, true}, {5}, + {false, false, true, true, false, false, true, true}); + // Zero-sized update. + RunR1({false, false, true, true, false, true, true, false}, {}, + {2}, {false, false, true, true, false, true, true, false}); +} + +XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) { + // Slice at dimension start. + RunR2( + {{false, true, false}, {true, false, true}, {false, true, true}}, + {{true, false}}, {0, 0}, + {{true, false, false}, {true, false, true}, {false, true, true}}); + // Slice in the middle. + RunR2( + {{false, true, false}, {true, false, true}, {false, true, true}}, + {{true, false}}, {1, 1}, + {{false, true, false}, {true, true, false}, {false, true, true}}); + // Slice at dimension boundaries. + RunR2( + {{false, true, false}, {true, false, true}, {false, true, true}}, + {{true, false}}, {2, 1}, + {{false, true, false}, {true, false, true}, {false, true, false}}); + // Zero-sized update. + RunR2( + {{false, true, false}, {true, false, true}, {false, true, true}}, {{}}, + {2, 1}, {{false, true, false}, {true, false, true}, {false, true, true}}); +} + +XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { + // R3 Shape: [2, 3, 2] + // Slice at dimension start. 
+ RunR3( + {{{true, false}, {false, true}, {true, true}}, + {{false, false}, {false, true}, {true, false}}}, + {{{false, true}, {true, false}}, {{true, true}, {false, true}}}, + {0, 0, 0}, + {{{false, true}, {true, false}, {true, true}}, + {{true, true}, {false, true}, {true, false}}}); + // Slice in the middle. + RunR3({{{true, false}, {false, true}, {true, true}}, + {{false, false}, {false, true}, {true, false}}}, + {{{false}, {true}}}, {1, 1, 1}, + {{{true, false}, {false, true}, {true, true}}, + {{false, false}, {false, false}, {true, true}}}); +} // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). @@ -470,7 +542,7 @@ void BM_DynamicSlice(int num_iters) { ComputationBuilder builder(client, "DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] - auto input_literal = LiteralUtil::CreateR4( + auto input_literal = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = builder.ConstantLiteral(*input_literal); @@ -488,7 +560,7 @@ void BM_DynamicSlice(int num_iters) { &allocator, 0) .ConsumeValueOrDie(); - auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); + auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *start_indices_literal, buffer->mutable_buffer({}))); @@ -521,7 +593,6 @@ BENCHMARK(BM_DynamicSlice); int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 
80267e5459d2ab12e3530110c0def699b7695351..90c5aa65592302e076821aaaeaa701ae40c07a6c 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -114,7 +113,6 @@ TEST_F(FloorCeilTest, R0Ceil) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index ee4e92505d9dd1f880473f1e76e5be3f01a1cfb3..9c86c65e5bb5b90072f79f5dee1923fa92b36e21 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -47,7 +46,6 @@ TEST_F(FmaxSimpleTest, FmaxTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index fa36381267e80e3afe693a4d85152d2367956be3..df52f168a8764e2a14e47230cb2a9095d60ddc0f 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -29,7 +31,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -37,10 +41,13 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" using tensorflow::gtl::ArraySlice; +namespace se = ::perftools::gputools; + namespace xla { namespace { @@ -81,7 +88,7 @@ class FusionTest : public HloTestBase { HloInstruction* hlos[4]; for (int i = 0; i < Arity; ++i) { hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(operand_data[i]))); + Literal::CreateR2FromArray2D(operand_data[i]))); } auto answer_shape = ShapeUtil::MakeShape(prim_type, {test_width, test_height}); @@ -107,7 +114,7 @@ class FusionTest : public HloTestBase { ArraySlice(hlos, 0, Arity + 1), HloInstruction::FusionKind::kLoop); - auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); + auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); @@ -178,28 +185,27 @@ XLA_TEST_F(FusionTest, Test) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); + 
Literal::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + Literal::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); + Literal::CreateR2({{1.62, 2.72, 3.14}}))); auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + Literal::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); - auto const10 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{true, false, true}, {false, true, false}}))); + Literal::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{true, false, true}, {false, true, false}}))); auto select11 = builder.AddInstruction( HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); @@ -214,7 +220,7 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - 
LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{0.5}, {2.72}}), + LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -226,11 +232,11 @@ XLA_TEST_F(FusionTest, Parameter) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); + Literal::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); + Literal::CreateR2({{-2.0, -2.0, -2.0}}))); // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); @@ -240,7 +246,7 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -249,9 +255,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + Literal::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); auto broadcast = builder.AddInstruction( 
HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); // add2 = broadcast(const_vector) + const_array @@ -265,7 +271,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -273,13 +279,13 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); + HloInstruction::CreateConstant(Literal::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {}), single_element_array)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(5), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -287,14 +293,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ 
-302,14 +308,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -317,13 +323,13 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); + HloInstruction::CreateConstant(Literal::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -331,13 +337,13 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + HloInstruction::CreateConstant(Literal::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); hlo_module->AddEntryComputation(builder.Build()) 
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR3({{{7}}}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -345,13 +351,13 @@ XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + HloInstruction::CreateConstant(Literal::CreateR0(7))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -359,14 +365,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -374,14 +380,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -389,14 +395,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -404,14 +410,14 @@ XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, 
HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({3, 2, 1}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -430,10 +436,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -441,7 +447,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(15), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -449,10 +455,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -462,7 +468,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({-15}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-15}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -470,9 +476,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + Literal::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); Window window; ASSERT_TRUE( tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" @@ -512,10 +518,46 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + *Literal::CreateR2({{462, 2145}, {24871, 62491}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } +// When a constant (or other op) which has multiple users is imported +// into a fusion, it should remain shared, rather than being duplicated +// within the fusion. 
+XLA_TEST_F(FusionTest, SharedConstant) { + auto hlo_module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0}))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); + auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); + auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction( + {add4, add3, add2, add1, const1}, + HloInstruction::FusionKind::kLoop); + + HloComputation* entry_comp = hlo_module->entry_computation(); + + // entry computation contains the constant(0) and the fusion + EXPECT_EQ(entry_comp->instructions().size(), 2); + + // fused instruction contains the constant(2), the parameter, and 4 adds + EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6); + + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } XLA_TEST_F(FusionTest, Subtract2D) { @@ -568,12 +610,66 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } +void BM_ParallelFusion(int num_iters) { + // Simple element-wise computation to benchmark parallel task partitioning. 
+ tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + StreamExecutorMemoryAllocator allocator(platform, executors); + + const int64 intra_op_parallelism_threads = 16; + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); + auto client = + ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); + + const int64 dim_size = 1024; + // Create a simple fusable elementwise computation. + ComputationBuilder builder(client, "ParallelFusion"); + Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); + auto input0 = builder.Broadcast(builder.ConstantR0(1.5f), + AsInt64Slice(input_shape.dimensions())); + auto input1 = builder.Broadcast(builder.ConstantR0(2.0f), + AsInt64Slice(input_shape.dimensions())); + auto input2 = builder.Broadcast(builder.ConstantR0(3.0f), + AsInt64Slice(input_shape.dimensions())); + auto x = builder.Mul(input0, input1); + auto y = builder.Add(x, input2); + auto computation = builder.Build().ConsumeValueOrDie(); + + std::unique_ptr executable = + client->Compile(computation, {}, ExecutableBuildOptions()) + .ConsumeValueOrDie(); + + // Run some warm-up executions. + ExecutableRunOptions options; + options.set_allocator(&allocator); + const int kWarmups = 2; + for (int i = 0; i < kWarmups; ++i) { + auto result = executable->Run({}, options); + ASSERT_TRUE(result.ok()); + } + + // Run benchmark. 
+ tensorflow::testing::BytesProcessed(static_cast(num_iters) * dim_size * + dim_size * sizeof(float)); + tensorflow::testing::UseRealTime(); + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + auto result = executable->Run({}, options); + ASSERT_TRUE(result.ok()); + } +} + +BENCHMARK(BM_ParallelFusion); + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -586,5 +682,6 @@ int main(int argc, char** argv) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; } + tensorflow::testing::RunBenchmarks(); return RUN_ALL_TESTS(); } diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index f54fa2256e217e9aa954a10470cd461023be631d..eded2077fce965ab1c729c610764afa2228ca128 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -46,7 +46,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { builder.ClearOpMetadata(); Shape argument_layout = ShapeUtil::MakeShape(F32, {}); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr executable, local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout, &argument_layout}, diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 5f7b7aa434e29980a7d813dfb57f3b7988ed6e6d..8149e2b7cc72018ef8deb61305bb61ceb77200f9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -24,14 +24,12 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -56,17 +54,6 @@ struct HloTestBase::EigenThreadPoolWrapper { HloTestBase::HloTestBase() : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) { - // TODO(b/62411181): get rid of this flag entirely when the usual debug flags - // are piped to all HLO tests. - test_hlo_dumper_ = [](const HloModule& module, const string& label) { - legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags(); - if (flags->xla_hlo_test_generate_hlo_graph) { - const bool show_addresses = true; - const bool show_layouts = true; - hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, - show_addresses, show_layouts); - } - }; VLOG(1) << "executing on platform " << backend_->platform()->Name(); } @@ -77,9 +64,16 @@ HloTestBase::~HloTestBase() { } } +/* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
+ debug_options.add_xla_disable_hlo_passes("constant_folding"); + + config.set_debug_options(debug_options); + return MakeUnique(TestName(), VersionedComputationHandle(), config); } @@ -91,7 +85,7 @@ StatusOr HloTestBase::Execute( Shape* result_shape) { TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend_->compiler()->Compile(std::move(module), test_hlo_dumper_, + backend_->compiler()->Compile(std::move(module), backend_->default_stream_executor())); se::Stream stream(backend_->default_stream_executor()); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 98bc35ae528970e262740631b283b7dbb6d01538..7f3d163290aba3cfcea1b3204e6c88134e172ed7 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -48,7 +48,7 @@ class HloTestBase : public ::testing::Test { // TestName() for its name; it will also automatically populate its debug // options from command-line flags. It's recommended to use this method to // create all HloModules for tests. - std::unique_ptr CreateNewModule(); + static std::unique_ptr CreateNewModule(); // Executes the given module and returns a global data handle. StatusOr Execute( @@ -104,8 +104,6 @@ class HloTestBase : public ::testing::Test { std::unique_ptr backend_; - Compiler::HloDumper test_hlo_dumper_; - // This vector contains handles of all the device memory allocations performed // by the test. These are deallocated on destruction of the test object. 
std::vector allocations_; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index eb979ad189db7b238ae6cc393d84d0c6c9fc27d1..0a8208332837545db27bff4e135feb586fc6429a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -41,20 +41,25 @@ namespace xla { /* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, const Shape& actual) { - ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); - ASSERT_EQ(expected.element_type(), actual.element_type()) - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); - for (int i = 0; i < expected.dimensions_size(); ++i) { - ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - ASSERT_EQ(expected.tuple_shapes_size(), actual.tuple_shapes_size()); - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); + if (ShapeUtil::IsTuple(expected)) { + ASSERT_EQ(ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + } + } else { + ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); + ASSERT_EQ(expected.element_type(), actual.element_type()) + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + for (int i = 0; i < expected.dimensions_size(); ++i) { + ASSERT_EQ(expected.dimensions(i), 
actual.dimensions(i)) + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } } } @@ -128,8 +133,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, tensorflow::gtl::MutableArraySlice multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = LiteralUtil::Get(expected, multi_index); - NativeT actual_value = LiteralUtil::Get(actual, multi_index); + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); ::testing::AssertionResult result = CompareEqual(expected_value, actual_value); return result; // Defines implicit coersion to bool. @@ -147,11 +152,15 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, } // namespace /* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual) { - EXPECT_TRUE(Equal(expected, actual)) << "expected:\n" - << LiteralUtil::ToString(expected) - << "\n\tvs actual:\n" - << LiteralUtil::ToString(actual); + const Literal& actual, + const string& message) { + EXPECT_TRUE(Equal(expected, actual)) + << "expected:\n" + << expected.ToString() << "\n\tvs actual:\n" + << actual.ToString() + << (message.empty() + ? 
"" + : tensorflow::strings::StrCat("\nmessage: ", message)); } /* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, @@ -161,8 +170,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); AssertEqualShapes(expected.shape(), actual.shape()); std::vector multi_index(expected.shape().dimensions_size(), 0); @@ -210,8 +219,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, ::testing::AssertionResult result = ::testing::AssertionSuccess(); if (!match) { result = ::testing::AssertionFailure() - << "expected: " << LiteralUtil::ToString(expected) - << "\nactual: " << LiteralUtil::ToString(actual); + << "expected: " << expected.ToString() + << "\nactual: " << actual.ToString(); VLOG(1) << result.message(); } return result; @@ -219,8 +228,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); @@ -247,8 +256,8 @@ class NearComparator { // within the error bound. Emits useful log messages and dumps literals to // temporary files on failure. Returns true if literals match. 
bool ExpectNear(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); @@ -282,9 +291,9 @@ class NearComparator { if (num_miscompares_ > 0) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << LiteralUtil::ToString(expected); + << " " << expected.ToString(); LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << LiteralUtil::ToString(actual); + << " " << actual.ToString(); } EXPECT_TRUE(num_miscompares_ == 0) << "\nmax relative mismatch at index " @@ -369,10 +378,9 @@ class NearComparator { void ExpectLiteralsNear(const Literal& expected, const Literal& actual, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { - bool near = - ExpectValuesNear(LiteralUtil::Get(expected, multi_index_), - LiteralUtil::Get(actual, multi_index_)); - LiteralUtil::Set(&miscompares_, multi_index_, !near); + bool near = ExpectValuesNear(expected.Get(multi_index_), + actual.Get(multi_index_)); + miscompares_.Set(multi_index_, !near); } else { for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index_[dimension] = i; @@ -431,14 +439,18 @@ class NearComparator { /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error) { - EXPECT_TRUE(Near(expected, actual, error)); + const ErrorSpec& error, + const string& message) { + EXPECT_TRUE(Near(expected, actual, error)) + << (message.empty() + ? 
"" + : tensorflow::strings::StrCat("\nmessage: ", message)); } /* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple( const Literal& expected, const Literal& actual, const ErrorSpec& error) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { @@ -504,8 +516,7 @@ class NearComparator { *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Allocate space in the new literal. - LiteralUtil::Reserve(ShapeUtil::ElementsIn(literal.shape()), - new_literal.get()); + new_literal->Reserve(ShapeUtil::ElementsIn(literal.shape())); // Copy data into new literal, element-by-element. for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -515,44 +526,36 @@ class NearComparator { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + 
new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index a8b07a2c5d13e93d068cd475cb96a727c8346cc5..f645c4e8dcda73806a4204876716b93aa5fb7185 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -64,7 +64,8 @@ class LiteralTestUtil { const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual); + static void ExpectEqual(const Literal& expected, const Literal& actual, + const string& message = ""); // Expects that expected and actual are Not Equal. static void ExpectNotEqual(const Literal& expected, const Literal& actual); @@ -110,7 +111,7 @@ class LiteralTestUtil { // Expects expected and actual to be Near with the given error. static void ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error); + const ErrorSpec& error, const string& message = ""); // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. 
@@ -130,6 +131,12 @@ class LiteralTestUtil { std::initializer_list>> expected, const Literal& actual, const ErrorSpec& error); + template + static void ExpectR4Near( + std::initializer_list>>> + expected, + const Literal& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. @@ -210,20 +217,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR0(expected), actual); + ExpectEqual(*Literal::CreateR0(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR1Equal( tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR1(expected), actual); + ExpectEqual(*Literal::CreateR1(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR2(expected), actual); + ExpectEqual(*Literal::CreateR2(expected), actual); } template @@ -231,46 +238,46 @@ template std::initializer_list>> expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR3(expected), actual); + ExpectEqual(*Literal::CreateR3(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR2FromArray2D(expected), actual); + ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR3FromArray3D(expected), actual); + ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR4FromArray4D(expected), actual); 
+ ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR0(expected), actual, error); + ExpectNear(*Literal::CreateR0(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR1Near( tensorflow::gtl::ArraySlice expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR1(expected), actual, error); + ExpectNear(*Literal::CreateR1(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR2(expected), actual, error); + ExpectNear(*Literal::CreateR2(expected), actual, error); } template @@ -278,28 +285,37 @@ template std::initializer_list>> expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR3(expected), actual, error); + ExpectNear(*Literal::CreateR3(expected), actual, error); +} + +template +/* static */ void LiteralTestUtil::ExpectR4Near( + std::initializer_list>>> + expected, + const Literal& actual, const ErrorSpec& error) { + ExpectNear(*Literal::CreateR4(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR2FromArray2D(expected), actual, error); + ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR3FromArray3D(expected), actual, error); + ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const Literal& actual, const 
ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); + ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); } template @@ -309,9 +325,9 @@ LiteralTestUtil::CreateRandomLiteral( const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = LiteralUtil::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - literal.get(), [&](tensorflow::gtl::ArraySlice indexes) { + std::unique_ptr literal = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { return generator(indexes); })); return std::move(literal); diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index a94f45f73b7d058d6b82f91967f61624a28fea3d..2acf27ed390b0732ba40fcf505c746bd7d8b651e 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,9 +31,8 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + std::unique_ptr literal = Literal::MakeTuple({ + Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); LiteralTestUtil::ExpectEqual(*literal, *literal); } @@ -43,13 +42,11 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. 
auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + std::unique_ptr lhs = Literal::MakeTuple({ + Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + std::unique_ptr rhs = Literal::MakeTuple({ + Literal::CreateR0(64).get(), Literal::CreateR0(42).get(), }); CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; }; @@ -58,8 +55,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto dummy_lambda = [] { - auto two = LiteralUtil::CreateR0(2); - auto four = LiteralUtil::CreateR0(4); + auto two = Literal::CreateR0(2); + auto four = Literal::CreateR0(4); ErrorSpec error(0.001); CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; }; @@ -88,11 +85,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { &literal_proto)); Literal literal(literal_proto); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", LiteralUtil::ToString(literal)); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", LiteralUtil::ToString(literal)); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("miscompares") != string::npos) { - EXPECT_EQ("true", LiteralUtil::ToString(literal)); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index 796f43ea4edc2c4858eb85c7fa8a16bbe8401a4b..4cb383a78dfed8a4867f4b589c6c32db345dfc9f 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -62,7 +61,6 @@ TEST_F(LogTest, LogTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index e4dbd6864a325546fabd88b56acf341b99cb73c8..47a8acbf4ab76758d8387e84eb271c130aba5a64 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -170,7 +169,7 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. 
ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + std::unique_ptr param0_literal = Literal::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -184,7 +183,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -199,7 +198,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -213,7 +212,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -226,7 +225,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -240,7 +239,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an 
input R1F32 vector. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -256,7 +255,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -273,7 +272,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // maps (lambda (x) (* x 2)) on the result. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -288,7 +287,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + std::unique_ptr param0_literal = Literal::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -385,11 +384,11 @@ TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. 
ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -434,12 +433,12 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -456,15 +455,15 @@ TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. 
ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr param2_literal = - LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + Literal::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); @@ -517,11 +516,11 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -531,9 +530,10 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT(computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: binary op with " - "different element types: f32[] and u16[]")); + EXPECT_THAT( + computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. 
MapTestWithFullOpt runs all @@ -554,8 +554,8 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { sub_builder->Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_literal = Literal::CreateR0(2.0f); + std::unique_ptr param1_literal = Literal::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -581,8 +581,8 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { sub_builder->Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_literal = Literal::CreateR0(2.0f); + std::unique_ptr param1_literal = Literal::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -606,7 +606,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { sub_builder->Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + std::unique_ptr param0_literal = Literal::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -622,7 +622,6 @@ TEST_F(MapTestWithFullOpt, MapSquare) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 
51261f0ac1c15ee96dd0f749fec35971d73b34f2..9ad9b33176691f361e03af35ede8030d5417592a 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -88,8 +87,8 @@ TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { builder.Exp(data); std::unique_ptr expected = - LiteralUtil::CreateR2({{2.71828, 1.00000}, // row 0 - {0.36788, 1.64872}}); // row 1 + Literal::CreateR2({{2.71828, 1.00000}, // row 0 + {0.36788, 1.64872}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -116,8 +115,8 @@ TEST_F(MatOpsSimpleTest, MapTwoByTwo) { auto map = builder.Map({data}, add_half); std::unique_ptr expected = - LiteralUtil::CreateR2({{1.5, 0.5}, // row 0 - {-0.5, 1.0}}); // row 1 + Literal::CreateR2({{1.5, 0.5}, // row 0 + {-0.5, 1.0}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -134,8 +133,8 @@ TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { auto max = builder.Max(lhs, rhs); std::unique_ptr expected = - LiteralUtil::CreateR2({{7.0, 6.0}, // row 0 - {3.0, -4.0}}); // row 1 + Literal::CreateR2({{7.0, 6.0}, // row 0 + {3.0, -4.0}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } @@ -179,16 +178,14 @@ TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { Shape rhs_shape = ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - 
lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSIGN_OR_ASSERT_OK( + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); ComputationBuilder builder(client_, TestName()); auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); @@ -218,7 +215,6 @@ INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 4929e25c580c427a3f034ccf82e7821222be0d8a..56c15e5ff7256cc75a10733e5934894cc88a34da 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -60,7 +59,6 @@ XLA_TEST_F(SliceTest, Slice3D) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b34e1d7db24fbbc5927102bce94f576f3e6d4947 --- /dev/null +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::gtl::ArraySlice; + +namespace xla { +namespace { + +class MultiOutputFusionTest : public HloTestBase { + public: + ErrorSpec error_spec_{0.0001, 1e-2}; + + protected: + MultiOutputFusionTest() {} + void RunTest2D(bool manual_fusion, int64 size) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + + const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); + const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(8.0f))); 
+ auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape0, "0")); + + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape0, HloOpcode::kAdd, param0, const0)); + + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(elem_shape2, add1, {0, 1})); + + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, elem_shape2, "1")); + + HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape2, HloOpcode::kAdd, broadcast, param1)); + HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); + + if (manual_fusion) { + auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( + ArraySlice({sub, add2}, 0, 2))); + auto gte0 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0)); + auto gte1 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 1)); + TF_CHECK_OK(dot->ReplaceOperandWith(0, gte0)); + TF_CHECK_OK(dot->ReplaceOperandWith(1, gte1)); + + CHECK_NE( + computation->CreateFusionInstruction( + {tuple, sub, add2, broadcast}, HloInstruction::FusionKind::kLoop), + nullptr); + } + + Literal input; + input.PopulateWithValue(2.5f, {size, size}); + auto p1 = TransferToDevice(input); + auto p0 = TransferToDevice(*Literal::CreateR0(-9.0f)); + + Literal expect; + expect.PopulateWithValue(size * 1.5f * 3.5f, {size, size}); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + } + + void RunTest1D(bool manual_fusion, int size) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = 
CreateNewModule(); + + const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); + const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape_F32, "0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, elem_shape_U8, "1")); + + HloInstruction* param0_U8 = builder.AddInstruction( + HloInstruction::CreateConvert(elem_shape_U8, param0)); + HloInstruction* param1_F32 = builder.AddInstruction( + HloInstruction::CreateConvert(elem_shape_F32, param1)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape_F32, HloOpcode::kAdd, param0, param1_F32)); + HloInstruction* sub_U8 = + builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape_U8, HloOpcode::kSubtract, param0_U8, param1)); + HloInstruction* sub = builder.AddInstruction( + HloInstruction::CreateConvert(elem_shape_F32, sub_U8)); + + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {size, 1}), add)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kDot, sub, reshape)); + auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); + + if (manual_fusion) { + auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( + ArraySlice({sub_U8, add}, 0, 2))); + + auto gte0 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0)); + auto gte1 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape_F32, tuple, 1)); + TF_CHECK_OK(sub->ReplaceOperandWith(0, gte0)); + TF_CHECK_OK(reshape->ReplaceOperandWith(0, gte1)); + + CHECK_NE(computation->CreateFusionInstruction( + {tuple, sub_U8, add, param0_U8, param1_F32}, + HloInstruction::FusionKind::kLoop), + nullptr); + } + + Literal input0, input1; + input0.PopulateWithValue(2.5f, {size}); + 
input1.PopulateWithValue(1, {size}); + auto p0 = TransferToDevice(input0); + auto p1 = TransferToDevice(input1); + + Literal expect = *Literal::CreateR0(size * 1.5f * 3.5f); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + } +}; + +XLA_TEST_F(MultiOutputFusionTest, 2DNofusion) { RunTest2D(false, 5); } +XLA_TEST_F(MultiOutputFusionTest, 2DFusion) { RunTest2D(true, 5); } +XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); } +XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) { RunTest1D(false, 8); } +XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); } + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 4922bbf21c447e4db193e63919d4df5f8079e3be..e270a0477fe140b75b6d4ddffb5d4d98ced2171d 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -183,8 +182,8 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); - auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = LiteralUtil::Relayout(*input, layout); + auto input = Literal::CreateR4FromArray4D(input_array); + input = input->Relayout(layout); b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); @@ -228,8 +227,8 @@ XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 0, 0, 0) = 1.0f; input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; - auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = LiteralUtil::Relayout(*input, layout); + auto input = Literal::CreateR4FromArray4D(input_array); + input = input->Relayout(layout); b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); @@ -308,7 +307,7 @@ XLA_TEST_F(PadTest, Large2DPad) { auto ones = MakeUnique>(4, 4); ones->Fill(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*ones); + auto input_literal = Literal::CreateR2FromArray2D(*ones); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -334,7 +333,7 @@ XLA_TEST_F(PadTest, AllTypes2DPad) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(0.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ 
-365,7 +364,7 @@ XLA_TEST_F(PadTest, High2DPad) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -397,7 +396,7 @@ XLA_TEST_F(PadTest, NegativePadding2D) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -429,7 +428,7 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -453,7 +452,7 @@ XLA_TEST_F(PadTest, ReducePad) { auto ones = MakeUnique>(2, 2, 2, 2); ones->Fill(1.0); - auto input_literal = LiteralUtil::CreateR4FromArray4D(*ones); + auto input_literal = Literal::CreateR4FromArray4D(*ones); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -470,7 +469,6 @@ XLA_TEST_F(PadTest, ReducePad) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if 
(!parse_result) { diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 3e1bfcd3090df6df69e344c157390a41476f17a4..a7692fceb4751a4e81851c382be0371efbff8dc8 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,8 +43,7 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -57,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -70,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -83,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, 
ConstantR1U8Param) { ComputationBuilder builder(client_, TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr param0_literal = Literal::CreateR1U8(str); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -96,7 +94,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); + Literal::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -108,7 +106,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + std::unique_ptr param0_literal = Literal::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -124,12 +122,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = builder.Parameter(1, literal1->shape(), "param1"); @@ -155,7 +153,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an 
incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + std::unique_ptr literal = Literal::CreateR0(3.14159f); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -173,12 +171,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = builder.Parameter(1, literal1->shape(), "param1"); @@ -193,12 +191,11 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. 
ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20, 30}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); @@ -238,7 +235,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); ComputationDataHandle param = @@ -268,9 +265,9 @@ XLA_TEST_F(ParamsTest, std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(*Literal::MakeTuple({ + Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR1({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -282,7 +279,7 @@ XLA_TEST_F(ParamsTest, // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 2}, {3, 4}, }); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -296,7 +293,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. 
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 3}, {2, 4}, }); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); @@ -309,7 +306,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 3}, {2, 4}, }); const Shape original = literal->shape(); @@ -322,7 +319,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { std::reverse(original_layout.begin(), original_layout.end()); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, LiteralUtil::Get(*literal, {0, 1})); + ASSERT_EQ(2, literal->Get({0, 1})); } // Use the original shape in building the computation. ComputationBuilder builder(client_, TestName()); @@ -344,7 +341,6 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/tools/ci_build/builds/tensorboard.sh b/tensorflow/compiler/xla/tests/plugin.bzl old mode 100755 new mode 100644 similarity index 60% rename from tensorflow/tools/ci_build/builds/tensorboard.sh rename to tensorflow/compiler/xla/tests/plugin.bzl index 77bd29c09f8a1009708ed2bd95987df954fd4a77..1b10c778ce3587d9b3f345a92abbb4da92bcad9b --- a/tensorflow/tools/ci_build/builds/tensorboard.sh +++ b/tensorflow/compiler/xla/tests/plugin.bzl @@ -1,5 +1,4 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Additional XLA devices to be included in the unit test suite.""" -set -e - -export LAUNCHPAD_CHROME=${LAUNCHPAD_CHROME:-$(which chromium-browser)} - -cd tensorflow/tensorboard - -# Install all js dependencies (tooling via npm, frontend assets via bower) -npm run prepare +# Example: +# +# plugins = { +# "foo": { +# "deps": [ +# "//tensorflow/compiler/plugin/foo:foo_lib", +# "//tensorflow/compiler/plugin/foo:test_macros", +# ], +# "copts": [], +# "tags": [], +# "args": [] +# }, +# } -npm run compile +plugins = {} -# Run wct in headless chrome using xvfb -xvfb-run ./node_modules/web-component-tester/bin/wct --skip-plugin=sauce diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index b031725d8abd897c83e40a3514bcccb7d7d76acf..d865297ae612f614f45aa6b4b226e15ee154ed2f 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -142,7 +141,6 @@ TEST_F(PredTest, AnyR2False) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5117478bfd55093a82a5fa361feb5cf59fd68fd1..0a2d4c763d204478683520f339574ca7738d8650 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -58,11 +57,10 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - LiteralUtil::EachCell(*actual, - [=](tensorflow::gtl::ArraySlice, T value) { - EXPECT_LE(a, value); - EXPECT_LT(value, b); - }); + actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + EXPECT_LE(a, value); + EXPECT_LT(value, b); + }); } void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { @@ -70,17 +68,16 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { auto shape = ShapeUtil::MakeShape(U32, dims); builder.RngBernoulli(builder.ConstantR0(p), shape); - TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); - ExecutionOptions execution_options; + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(42); - TF_ASSIGN_OR_ASSERT_OK( - auto actual, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options)); + TF_ASSERT_OK_AND_ASSIGN( + auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options)); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); int32 sum = 0; - LiteralUtil::EachCell( - *actual, [&sum](tensorflow::gtl::ArraySlice, uint32 value) { + actual->EachCell( + [&sum](tensorflow::gtl::ArraySlice, uint32 value) { EXPECT_TRUE(value == 0 || value == 1); sum += value; }); @@ -124,10 +121,8 @@ double 
PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); std::vector counts(range_size, 0); - LiteralUtil::EachCell( - *actual, [&counts](tensorflow::gtl::ArraySlice, int32 value) { - ++counts[value]; - }); + actual->EachCell([&counts](tensorflow::gtl::ArraySlice, + int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); @@ -170,23 +165,22 @@ XLA_TEST_F(PrngTest, MapUsingRng) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr param0_data, - client_->TransferToServer(*param0_literal)); + Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, + client_->TransferToServer(*param0_literal)); auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); builder.Map({param0}, fn); - TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(125); - TF_ASSIGN_OR_ASSERT_OK( - auto actual, - client_->ExecuteAndTransfer(computation, - /*arguments=*/{param0_data.get()}, - &execution_options)); + TF_ASSERT_OK_AND_ASSIGN( + auto actual, client_->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get()}, &execution_options)); EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size()); for (int i = 0; i < param0_literal->f32s_size(); ++i) { @@ -209,47 +203,45 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { return builder.Build(); }; - ExecutionOptions execution_options1; + ExecutionOptions execution_options1 = execution_options_; execution_options1.set_seed(42); - ExecutionOptions 
execution_options2; + ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); std::unique_ptr result1; { - TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); - TF_ASSIGN_OR_ASSERT_OK( - result1, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options1)); + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options1)); } std::unique_ptr result2; std::unique_ptr result3; { - TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); - TF_ASSIGN_OR_ASSERT_OK( - result2, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options1)); - TF_ASSIGN_OR_ASSERT_OK( - result3, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options1)); + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result2, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options1)); + TF_ASSERT_OK_AND_ASSIGN( + result3, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options1)); } std::unique_ptr result4; std::unique_ptr result5; std::unique_ptr result6; { - TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); - TF_ASSIGN_OR_ASSERT_OK( - result4, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options2)); - TF_ASSIGN_OR_ASSERT_OK( - result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{})); - TF_ASSIGN_OR_ASSERT_OK( - result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{})); + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result4, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options2)); + TF_ASSERT_OK_AND_ASSIGN( + result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options_)); + TF_ASSERT_OK_AND_ASSIGN( + 
result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options_)); } LiteralTestUtil::ExpectEqual(*result1, *result2); @@ -273,13 +265,23 @@ XLA_TEST_F(PrngTest, TenValuesN01) { // TODO(b/25995601): Test that resultant values are reasonable } +XLA_TEST_F(PrngTest, RngUniformCrash) { + ComputationBuilder builder(client_, TestName()); + + // This used to crash XLA during LLVM IR generation for CPUs. + auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), + builder.ConstantR0(1000 * 1000), + ShapeUtil::MakeShape(S32, {})); + SetSeed(0); + ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 4a02567a1a2ea8014cceca085c3d3d8589d6500f..0078733e197685fea575e78b8435485ea9de4926 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -46,7 +45,6 @@ TEST_F(QueryInferredShapeTest, OnePlusOneShape) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..527205bbb0d8d6069ec1450a3cade1663b85616e --- /dev/null +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -0,0 +1,344 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Tests to confirm that the ReducePrecision operation produces the expected +// numerical values. +class ReducePrecisionAccuracyTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { +}; + +// For reduction to IEEE-f16, we want to test the following cases, in both +// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10 +// mantissa bits.) +// +// Vectors of exponent and mantissa sizes to test. We want to test IEEE-f32 (a +// no-op), IEEE-f16, and exponent-reduction-only and mantissa-reduction-only +// variants of IEEE-f16. 
+static const int exponent_sizes[] = {8, 5, 5, 8}; +static const int mantissa_sizes[] = {23, 10, 23, 10}; + +string TestDataToString(const ::testing::TestParamInfo data) { + int i = data.param; + return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_", + mantissa_sizes[i], "_mantissa_bits"); +} + +// The FPVAL macro allows us to write out the binary representation of the +// input and expected values in a more readable manner. The mantissa bits +// are separated into the "high" bits (retained with reduction to IEEE-f16) +// and the "low" bits (truncated with reduction to IEEE-f16). +#define FPVAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \ + ((0b##EXPONENT << 23) + (0b##HIGH_MANTISSA << 13) + (0b##LOW_MANTISSA)) + +// Each element in the test-value array consists of four numbers. The first is +// the input value and the following are the expected output values for the +// various precision-reduction cases. +static const uint32_t test_values[][4] = { + // True zero. + { + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000) // 0.0 + }, + // Largest exponent that underflows to zero. + { + FPVAL(01110000, 0000000000, 0000000000000), // 3.05176e-05 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(01110000, 0000000000, 0000000000000) // 3.05176e-05 + }, + // Largest value that rounds to a denormal and thus clamps to zero. + { + FPVAL(01110000, 1111111111, 0111111111111), // 6.10203e-05 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(01110000, 1111111111, 0000000000000) // 6.10054e-05 + }, + // Smallest value that doesn't underflow to zero, due to mantissa rounding + // up and incrementing the exponent out of the denormal range. 
+ { + FPVAL(01110000, 1111111111, 1000000000000), // 6.10203e-05 + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05 + }, + // Smallest value that doesn't underflow to zero even without mantissa + // rounding. + { + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05 + }, + // One (to make sure bias-handling is done correctly. + { + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 0000000000000) // 1.0 + }, + // Values in a space where ties round down due to ties-to-even: + // Value with highest mantissa that rounds down. + { + FPVAL(01111111, 0000000000, 1000000000000), // 1.00049 + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 1000000000000), // 1.00049 + FPVAL(01111111, 0000000000, 0000000000000) // 1.0 + }, + // Value with lowest mantissa that rounds up. + { + FPVAL(01111111, 0000000000, 1000000000001), // 1.00049 + FPVAL(01111111, 0000000001, 0000000000000), // 1.00098 + FPVAL(01111111, 0000000000, 1000000000001), // 1.00049 + FPVAL(01111111, 0000000001, 0000000000000) // 1.00098 + }, + // Values in a space where ties round up due to ties-to-even: + // Value with highest mantissa that rounds down. + { + FPVAL(01111111, 0000000001, 0111111111111), // 1.00146 + FPVAL(01111111, 0000000001, 0000000000000), // 1.00098 + FPVAL(01111111, 0000000001, 0111111111111), // 1.00146 + FPVAL(01111111, 0000000001, 0000000000000) // 1.00098 + }, + // Value with a mantissa that rounds up. 
+ { + FPVAL(01111111, 0000000001, 1000000000000), // 1.00146 + FPVAL(01111111, 0000000010, 0000000000000), // 1.00195 + FPVAL(01111111, 0000000001, 1000000000000), // 1.00146 + FPVAL(01111111, 0000000010, 0000000000000) // 1.00195 + }, + // Largest value that does not overflow to infinity. + { + FPVAL(10001110, 1111111111, 0111111111111), // 65520.0 + FPVAL(10001110, 1111111111, 0000000000000), // 65504.0 + FPVAL(10001110, 1111111111, 0111111111111), // 65520.0 + FPVAL(10001110, 1111111111, 0000000000000) // 65504.0 + }, + // Smallest value that overflows to infinity due to mantissa rounding up. + { + FPVAL(10001110, 1111111111, 1000000000000), // 65520.0 + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(10001110, 1111111111, 1000000000000), // 65520.0 + FPVAL(10001111, 0000000000, 0000000000000) // 65536.0 + }, + // Smallest value that overflows to infinity, without mantissa rounding. + { + FPVAL(10001111, 0000000000, 0000000000000), // 65536.0 + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(10001111, 0000000000, 0000000000000) // 65536.0 + }, + // Smallest value that overflows to infinity due to mantissa rounding up, + // even when exponent bits aren't reduced. + { + FPVAL(11111110, 1111111111, 1000000000000), // 3.40199e+38 + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000) // Inf + }, + // True infinity. + { + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000) // Inf + }, + // NAN with a 1 in the preserved bits. 
+ { + FPVAL(11111111, 1000000000, 0000000000000), // NaN + FPVAL(11111111, 1000000000, 0000000000000), // NaN + FPVAL(11111111, 1000000000, 0000000000000), // NaN + FPVAL(11111111, 1000000000, 0000000000000) // NaN + }, + // NAN with a 1 in the truncated bits. + { + FPVAL(11111111, 0000000000, 0000000000001), // NaN + FPVAL(11111111, 0000000000, 0000000000001), // NaN + FPVAL(11111111, 0000000000, 0000000000001), // NaN + FPVAL(11111111, 0000000000, 0000000000001) // NaN + }, + // NAN with all ones, causing rounding overflow. + { + FPVAL(11111111, 1111111111, 1111111111111), // NaN + FPVAL(11111111, 1111111111, 1111111111111), // NaN + FPVAL(11111111, 1111111111, 1111111111111), // NaN + FPVAL(11111111, 1111111111, 1111111111111) // NaN + }}; + +XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { + int index = GetParam(); + int exponent_bits = exponent_sizes[index]; + int mantissa_bits = mantissa_sizes[index]; + + std::vector input_values; + std::vector expected_values; + + const uint32_t sign_bit = 1u << 31; + for (const auto& test_value : test_values) { + // Add positive values. + input_values.push_back(tensorflow::bit_cast(test_value[0])); + expected_values.push_back(tensorflow::bit_cast(test_value[index])); + // Add negative values. We do this in the bitwise representation so as to + // avoid problems with NaN handling. + input_values.push_back( + tensorflow::bit_cast(test_value[0] ^ sign_bit)); + expected_values.push_back( + tensorflow::bit_cast(test_value[index] ^ sign_bit)); + } + + // This is required for proper handling of NaN values. 
+ SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr a_literal = Literal::CreateR1({input_values}); + std::unique_ptr a_data = + client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + auto a = builder.Parameter(0, a_literal->shape(), "a"); + + auto reduce_precision = + builder.ReducePrecision(a, exponent_bits, mantissa_bits); + + ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); +} + +INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest, + ReducePrecisionAccuracyTest, + ::testing::Values(0, 1, 2, 3), TestDataToString); + +// Tests to confirm that the compiler optimization functions add the expected +// ReducePrecisionInsertion passes. +class ReducePrecisionInsertionTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_data = + client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + auto a = builder.Parameter(0, a_literal->shape(), "a"); + + // Abs doesn't affect resolution. + auto abs = builder.Abs(a); + + // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the + // reduce-precision operation showed up in the correct place in the + // graph. + auto log = builder.Log(abs); + + // Insert precision-reduction after the Abs(x) operation, rounding that + // result to exactly 1.0f. 
+ auto reduce_precision_pass = execution_options_.mutable_debug_options() + ->add_hlo_reduce_precision_options(); + *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( + HloReducePrecisionOptions::BEFORE_OP_FUSION, 5, 10, + [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); + + ComputeAndCompareR1(&builder, {0.0f}, {a_data.get()}); +} + +XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_data = + client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + auto a = builder.Parameter(0, a_literal->shape(), "a"); + + // These two operations should be fused by any reasonable backend. + auto abs = builder.Abs(a); + auto neg = builder.Neg(abs); + + // Add a pass after operation fusion, suffixing kAbs operations. This + // should not see into the fusion nodes and thus should not affect the + // result. + auto reduce_precision_pass = execution_options_.mutable_debug_options() + ->add_hlo_reduce_precision_options(); + *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( + HloReducePrecisionOptions::AFTER_OP_FUSION, 5, 10, + [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); + + ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); +} + +XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_data = + client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + auto a = builder.Parameter(0, a_literal->shape(), "a"); + + // These two operations should be fused by any reasonable backend. + auto abs = builder.Abs(a); + auto neg = builder.Neg(abs); + + // Add a pass after operation fusion, suffixing kFusion operations. 
+ auto reduce_precision_pass = execution_options_.mutable_debug_options() + ->add_hlo_reduce_precision_options(); + *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( + HloReducePrecisionOptions::AFTER_OP_FUSION, 5, 10, + [](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; }); + + ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index ff24177520eab5c6c2061d01223530249050448c..b22866fc84bec6e9e802f18fdea4c17c6f92e40f 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -40,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -64,12 +63,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. 
// clang-format off - literal_2d_ = LiteralUtil::CreateR2({ + literal_2d_ = Literal::CreateR2({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = LiteralUtil::CreateR3Projected({ + literal_3d_ = Literal::CreateR3Projected({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -98,7 +97,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -130,7 +129,7 @@ class ReduceTest : public ClientLibraryTestBase { builder.Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + std::unique_ptr input_literal = Literal::CreateR1(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -157,9 +156,9 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout({minor, major})); + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -185,9 +184,9 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout({minor, major})); + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr 
input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -203,6 +202,102 @@ class ReduceTest : public ClientLibraryTestBase { ErrorSpec(0.01, 1e-4)); } + template + void ComputeAndCompareGeneric( + typename std::enable_if::value, + ComputationBuilder>::type* builder, + tensorflow::gtl::ArraySlice expected, + tensorflow::gtl::ArraySlice arguments) { + ComputeAndCompareR1(builder, expected, arguments, + ErrorSpec(0.01, 1e-4)); + } + + template + void ComputeAndCompareGeneric( + typename std::enable_if::value, + ComputationBuilder>::type* builder, + tensorflow::gtl::ArraySlice expected, + tensorflow::gtl::ArraySlice arguments) { + ComputeAndCompareR1(builder, expected, arguments); + } + + template + void RunVectorizedReduceTestForType( + const std::function& + reduction_function_generator, + const std::function& + reference_reduction_function, + const NativeT& initial_value) { + const int rows = 64, cols = 128; + const int minor = 1, major = 0; + ComputationBuilder builder(client_, TestName()); + Computation reduction_function = reduction_function_generator(&builder); + const Shape input_shape = ShapeUtil::MakeShape( + xla::primitive_util::NativeToPrimitiveType(), {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(initial_value); + builder.Reduce(input, zero, reduction_function, + /*dimensions_to_reduce=*/{0}); + + Array2D input_data(rows, cols); + input_data.FillUnique(initial_value); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + // NativeT can be bool, and std::vector does not convert to + // ArraySlice. 
+ std::unique_ptr expected(new NativeT[cols]); + for (int64 colno = 0; colno < cols; ++colno) { + NativeT column_result = initial_value; + for (int64 rowno = 0; rowno < rows; ++rowno) { + column_result = reference_reduction_function(column_result, + input_data(rowno, colno)); + } + expected[colno] = column_result; + } + + ComputeAndCompareGeneric( + &builder, tensorflow::gtl::ArraySlice(expected.get(), cols), + {input_global_data.get()}); + } + + void RunVectorizedReduceTest( + const std::function& + reduction_function_generator_for_type, + const std::function& + reference_reduction_function_for_floats, + const std::function& + reference_reduction_function_for_ints, + const std::function& + reference_reduction_function_for_uints, + float floating_point_identity, int32 signed_int_identity, + uint32 unsigned_int_identity) { + // Float version + RunVectorizedReduceTestForType( + [&](ComputationBuilder* builder) { + return reduction_function_generator_for_type(F32, builder); + }, + reference_reduction_function_for_floats, floating_point_identity); + + // Signed int version + RunVectorizedReduceTestForType( + [&](ComputationBuilder* builder) { + return reduction_function_generator_for_type(S32, builder); + }, + reference_reduction_function_for_ints, signed_int_identity); + + // Unsigned int version + RunVectorizedReduceTestForType( + [&](ComputationBuilder* builder) { + return reduction_function_generator_for_type(U32, builder); + }, + reference_reduction_function_for_uints, unsigned_int_identity); + } + std::unique_ptr literal_2d_; std::unique_ptr literal_3d_; uint32 seed_ = 0xdeadbeef; @@ -306,9 +401,8 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = - LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + Literal::CreateR2FromArray2D(input_data); + input_literal = 
input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -339,9 +433,8 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = - LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + Literal::CreateR2FromArray2D(input_data); + input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -372,7 +465,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -435,7 +528,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + auto input_literal = Literal::CreateR2FromArray2D(input); builder.Reduce(builder.ConstantLiteral(*input_literal), builder.ConstantR0(FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; @@ -450,7 +543,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + auto input_literal = Literal::CreateR2FromArray2D(input); builder.Reduce(builder.ConstantLiteral(*input_literal), builder.ConstantR0(FLT_MAX), min, {0, 1}); @@ -460,6 +553,32 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { ComputeAndCompareR0(&builder, input_min, {}, 
ErrorSpec(0.0001)); } +XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { + ComputationBuilder builder(client_, TestName()); + Array2D input({{1}, {2}}); + auto min = CreateScalarMinComputation(U32, &builder); + auto input_literal = Literal::CreateR2FromArray2D(input); + auto initial_value = + builder.ConstantR0(std::numeric_limits::max()); + + builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, min, + {0, 1}); + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { + ComputationBuilder builder(client_, TestName()); + Array2D input({{1}, {2}}); + auto max = CreateScalarMaxComputation(U32, &builder); + auto input_literal = Literal::CreateR2FromArray2D(input); + auto initial_value = + builder.ConstantR0(std::numeric_limits::min()); + + builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, max, + {0, 1}); + ComputeAndCompareR0(&builder, 2, {}); +} + // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { ComputationBuilder builder(client_, TestName()); @@ -571,6 +690,58 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(ReduceTest, VectorizedReduce_Add) { + RunVectorizedReduceTest(CreateScalarAddComputation, + [](float a, float b) { return a + b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) + + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); +} + +XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) { + RunVectorizedReduceTest(CreateScalarMultiplyComputation, + [](float a, float b) { return a * b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) * + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); +} + +XLA_TEST_F(ReduceTest, VectorizedReduce_Max) { + RunVectorizedReduceTest(CreateScalarMaxComputation, + [](float a, float b) { return std::max(a, b); }, + [](int32 a, int32 b) { return std::max(a, b); }, + [](uint32 
a, uint32 b) { return std::max(a, b); }, + std::numeric_limits::min(), + std::numeric_limits::min(), + std::numeric_limits::min()); +} + +XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { + RunVectorizedReduceTest(CreateScalarMinComputation, + [](float a, float b) { return std::min(a, b); }, + [](int32 a, int32 b) { return std::min(a, b); }, + [](uint32 a, uint32 b) { return std::min(a, b); }, + std::numeric_limits::max(), + std::numeric_limits::max(), + std::numeric_limits::max()); +} + +XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalAnd) { + RunVectorizedReduceTestForType(CreateScalarLogicalAndComputation, + [](bool a, bool b) { return a && b; }, + true); +} + +XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalOr) { + RunVectorizedReduceTestForType(CreateScalarLogicalOrComputation, + [](bool a, bool b) { return a || b; }, + false); +} + class ReduceR3ToR2Test : public ReduceTest, public ::testing::WithParamInterface {}; @@ -580,9 +751,9 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { Array3D input_array(bounds[0], bounds[1], bounds[2]); input_array.FillRandom(3.14f, 0.05); - auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout(GetParam().layout)); + auto input_literal = Literal::CreateR3FromArray3D(input_array); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -630,7 +801,6 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc 
b/tensorflow/compiler/xla/tests/reduce_window_test.cc index ec7b47bc283538d7d9219610e4297fee8028d07f..9774e409411cd9726c5955be62b166bf4dc3712d 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +57,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_strides, Padding padding) { builder_.ReduceWindow( - input, builder_.ConstantLiteral(LiteralUtil::MinValue(F32)), + input, builder_.ConstantLiteral(Literal::MinValue(F32)), CreateScalarMax(), window_dimensions, window_strides, padding); } @@ -67,7 +66,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_strides, Padding padding) { builder_.ReduceWindow(input, - builder_.ConstantLiteral(LiteralUtil::MaxValue(F32)), + builder_.ConstantLiteral(Literal::MaxValue(F32)), CreateScalarMinComputation(F32, &builder_), window_dimensions, window_strides, padding); } @@ -75,6 +74,12 @@ class ReduceWindowTest : public ClientLibraryTestBase { ComputationBuilder builder_; }; +TEST_F(ReduceWindowTest, Min3In5Stride2) { + const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); + ReduceWindowMin(input, {3}, {2}, Padding::kValid); + ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { Array4D input_array(1, 0, 2, 1); @@ -132,6 +137,26 @@ TEST_F(ReduceWindowTest, Along2ndMinorDim) { ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); } 
+TEST_F(ReduceWindowTest, AmongMajor2Dims) { + Array4D input_array(4, 4, 6, 8); + input_array.FillWithMinorDimNum(); + + int win_len = 3; + int win_stride = 1; + + Padding padding = Padding::kSame; + const auto input_data_handle = + builder_.ConstantR4FromArray4D(input_array); + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); +} + TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); input_array.FillRandom(2.0f); @@ -184,202 +209,6 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); } -// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. 
-TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) { - Array4D input_array(2, 2, 4, 16); - - Array2D yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, - 11.f, 12.f, 13.f, 14.f, 15.f}, - {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}, - {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f}, - {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}}); - input_array.FillWithYX(yx); - - int win_len = 2; - int win_stride = 2; - const auto input = builder_.ConstantR4FromArray4D(input_array); - Padding padding = Padding::kValid; - ReduceWindowAdd(input, {1, 1, win_len, win_len}, - {1, 1, win_stride, win_stride}, padding); - - auto res = ReferenceUtil::ReduceWindow4DAdd( - input_array, 0.0f, {1, 1, win_len, win_len}, - {1, 1, win_stride, win_stride}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); -} - -// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. 
-TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmallOverlapped) { - constexpr int64 p = 2; - constexpr int64 z = 2; - constexpr int64 y = 4; - constexpr int64 x = 16; - Array4D input_array(p, z, y, x); - - Array2D yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, - 11.f, 12.f, 13.f, 14.f, 15.f}, - {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}, - {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f}, - {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}}); - input_array.FillWithYX(yx); - - int win_len = 4; - int win_stride = 2; - const auto input = builder_.ConstantR4FromArray4D(input_array); - ReduceWindowAdd(input, {1, 1, win_len, win_len}, - {1, 1, win_stride, win_stride}, Padding::kValid); - - // Expected result - Array2D yx_result({{408.f, 440.f, 472.f, 504.f, 536.f, 568.f, 600.f}}); - Array4D expected(p, z, 1, 7); - expected.FillWithYX(yx_result); - ComputeAndCompareR4(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3)); -} - -TEST_F(ReduceWindowTest, MaxTrivial) { - const auto input = builder_.ConstantR1({42}); - ReduceWindowMax(input, {1}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {42}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add3In3) { - const auto input = builder_.ConstantR1({20, 100, 3}); - ReduceWindowAdd(input, {3}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {123}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add4In16Stride4) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowAdd(input, {4}, {4}, Padding::kValid); - ComputeAndCompareR1(&builder_, {10, 26, 42, 58}, {}, - ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, DISABLED_ON_CPU(DISABLED_ON_GPU(Min3In5Stride2))) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); - ReduceWindowMin(input, {3}, {2}, 
Padding::kValid); - ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max3In3) { - const auto input = builder_.ConstantR1({20, 100, 3}); - ReduceWindowMax(input, {3}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {100}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add2In3) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {2}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {110, 11}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); - ReduceWindowAdd(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {11100, 111}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max4In16Stride4) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowMax(input, {4}, {4}, Padding::kValid); - ComputeAndCompareR1(&builder_, {4, 8, 12, 16}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max4In16Stride3) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowMax(input, {4}, {3}, Padding::kValid); - ComputeAndCompareR1(&builder_, {4, 7, 10, 13, 16}, {}, - ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max4In16Stride8) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowMax(input, {4}, {8}, Padding::kValid); - ComputeAndCompareR1(&builder_, {4, 12}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); - ReduceWindowMax(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {10000, 100}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max3In5Stride1) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 101}); - ReduceWindowMax(input, {3}, {1}, 
Padding::kValid); - ComputeAndCompareR1(&builder_, {10000, 1000, 101}, {}, - ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add3In4Stride2) { - const auto input = builder_.ConstantR1({1000, 100, 10, 1}); - ReduceWindowAdd(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1110}, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add2In3SamePad) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {2}, {1}, Padding::kSame); - ComputeAndCompareR1(&builder_, {110, 11, 1}, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add3In3SamePad) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {3}, {1}, Padding::kSame); - ComputeAndCompareR1(&builder_, {110, 111, 11}, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add3In3Stride3SamePad) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {3}, {2}, Padding::kSame); - ComputeAndCompareR1(&builder_, {110, 11}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add2x2In2x2Overlapped) { - Array2D input_array({{1.2f, -2.5f, 0.9f, 1.0f}, - {3.7f, 0.2f, -1.0f, -0.2f}, - {-0.4f, 2.7f, 1.1f, 2.2f}, - {0.6f, 1.7f, 1.4f, -0.2f}}); - auto input = builder_.ConstantR2FromArray2D(input_array); - ReduceWindowAdd(input, {2, 2}, {1, 1}, Padding::kValid); - Array2D expected( - {{2.6f, -2.4f, 0.7f}, {6.2f, 3.0f, 2.1f}, {4.6f, 6.9f, 4.5f}}); - ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) { - Array2D input_array({{1.2f, -2.5f, 0.9f, 1.0f}, - {3.7f, 0.2f, -1.0f, -0.2f}, - {-0.4f, 2.7f, 1.1f, 2.2f}, - {0.6f, 1.7f, 1.4f, -0.2f}}); - auto input = builder_.ConstantR2FromArray2D(input_array); - ReduceWindowAdd(input, {2, 2}, {2, 2}, Padding::kValid); - Array2D expected({ - {2.6f, 0.7f}, {4.6f, 4.5f}, - }); - ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add1x2In2x2Same) { - Array2D 
input_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); - auto input = builder_.ConstantR2FromArray2D(input_array); - ReduceWindowAdd(input, {1, 2}, {1, 1}, Padding::kSame); - Array2D expected({ - {3.0f, 2.0f}, {7.0f, 4.0f}, - }); - ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); -} - XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; @@ -470,13 +299,621 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); } +TEST_F(ReduceWindowTest, R4UnitWindow) { + Array4D input_array(13, 12, 8, 15); + input_array.Fill(1.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, + {1, 4, 1, 1}, padding); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { + Array4D input_array(2, 1, 27, 119); + input_array.FillRandom(2.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + int win_len = 1; + int stride = 8; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + 
ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { + Array4D input_array(3, 2, 4, 64); + input_array.FillRandom(2.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + int win_len = 3; + int stride = 1; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { + Array4D input_array(1, 3, 12, 200); + input_array.FillRandom(2.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + int win_len = 8; + int stride = 5; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { + Array4D input_array(6, 4, 10, 130); + input_array.FillRandom(2.0f); + + int win_len = 3; + int win_stride = 2; + + Padding padding = Padding::kSame; + const auto input_data_handle = + builder_.ConstantR4FromArray4D(input_array); + // Reduce 
only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { + std::vector input_vector(128 * 9, 1); + const auto input = builder_.ConstantR1(input_vector); + ReduceWindowAdd(input, {32}, {128}, Padding::kValid); + ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, + {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { + const auto input = builder_.ConstantR1( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + ReduceWindowAdd(input, {128}, {128}, Padding::kValid); + ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); +} + +// Regression test for a bug that appeared in Inception (b/34784899). 
+TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { + Array2D input_array(14, 14, 1.0f); + ComputationDataHandle input = + builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + + int win_len = 3; + int stride = 1; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding); + + auto res = ReferenceUtil::ReduceWindow2DAdd( + input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); + + ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { + Array2D input_array(6, 4, 1.0f); + ComputationDataHandle input = + builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); + + auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, + padding); + + ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +enum Reducer { kAdd, kMax }; + +struct R4ReduceWindowTestData { + int64 base_bounds[4]; + int64 window_bounds[4]; + int64 strides[4]; + int64 pad_low[4]; + int64 pad_high[4]; + + Reducer reducer; +}; + +string R4ReduceWindowTestDataToString( + const ::testing::TestParamInfo& data) { + string str = tensorflow::strings::StrCat( + "base_bounds_", + tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "__window_bounds_", + tensorflow::str_util::Join(data.param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // + (data.param.reducer == kAdd) ? "add" : "max"); + CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + + // Test names are not allowed to contain the '-' character. 
+ std::replace(str.begin(), str.end(), '-', 'n'); + return str; +} + +class R4ReduceWindowTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + protected: + void DoIt() { + ComputationBuilder b(client_, TestName()); + const auto& param = GetParam(); + + const float kInitValue = 0.0f; + + Array4D input(param.base_bounds[0], param.base_bounds[1], + param.base_bounds[2], param.base_bounds[3]); + input.FillIota(1); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4D(input); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, + client_->TransferToServer(*input_literal)); + + std::vector> padding(4); + for (int i = 0; i < 4; ++i) { + padding[i] = {param.pad_low[i], param.pad_high[i]}; + } + + auto parameter = b.Parameter(0, input_literal->shape(), "p0"); + auto pad_value = b.ConstantR0(kInitValue); + CHECK(param.reducer == kAdd || param.reducer == kMax); + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(F32, &b) + : CreateScalarMaxComputation(F32, &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/pad_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, + /*padding=*/padding); + + CHECK(param.reducer == kAdd || param.reducer == kMax); + auto reduce_func = param.reducer == kAdd + ? +[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + std::unique_ptr> expected = + ReferenceUtil::ReduceWindow4DGeneric( + /*operand=*/input, + /*init=*/kInitValue, + /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, + /*padding=*/padding); + ComputeAndCompareR4(&b, *expected, {input_arg.get()}, + ErrorSpec(1e-3, 1e-3)); + } +}; + +TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); } + +// base_bounds, window_bounds, strides, pad_low, pad_high +const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { + // Minimal edge case. 
+ R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // Zero base bound edge case. + R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With non-1x1 window. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With max instead of add. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kMax}, + + // With stride. + R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140}, + /*window_bounds=*/{3, 2, 1, 1}, + /*strides=*/{2, 4, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With low padding. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{3, 2, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{3, 2, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With high padding. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{3, 2, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{2, 3, 0, 0}, + /*reducer=*/kAdd}, + + // Window touches both sides of the padding simultaneously. + R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1, 1, 0, 0}, + /*pad_high=*/{1, 1, 0, 0}, + /*reducer=*/kAdd}, + + // Window is entirely in the padding for some positions. 
+ R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{4, 4, 0, 0}, + /*pad_high=*/{4, 4, 0, 0}, + /*reducer=*/kAdd}, + + // Zero base bound with padding edge case. + R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 1, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With stride, low padding and high padding. + R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140}, + /*window_bounds=*/{3, 4, 1, 1}, + /*strides=*/{3, 1, 1, 1}, + /*pad_low=*/{10, 1, 0, 0}, + /*pad_high=*/{2, 3, 0, 0}, + /*reducer=*/kAdd}, + + // With second minor dimension == 9. + R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With minor dimension == 129. + R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With minor dims reduction and non-overlapped stride. + R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16}, + /*window_bounds=*/{1, 1, 2, 2}, + /*strides=*/{1, 1, 2, 2}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With minor dims reduction and overlapped stride. 
+ R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16}, + /*window_bounds=*/{1, 1, 4, 4}, + /*strides=*/{1, 1, 2, 2}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, +}; + +INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::ValuesIn(kR4ReduceWindowTestValues), + R4ReduceWindowTestDataToString); + +class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; + +XLA_TEST_P(R4ReduceWindowLargeTest, DoIt) { DoIt(); } + +// Test cases that are large/slow/failed. +const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { + R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1, 1, 0, 0}, + /*pad_high=*/{1, 1, 0, 0}, + /*reducer=*/kMax}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{1, 1, 0, 0}, + /*reducer=*/kAdd}, +}; + +INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, + R4ReduceWindowLargeTest, + ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + R4ReduceWindowTestDataToString); + +struct R2ReduceWindowTestData { + int64 base_bounds[2]; + int64 window_bounds[2]; + int64 strides[2]; + int64 layout[2]; + Padding padding; + Reducer reducer; +} kR2TestCases[] = { + {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, + /*strides=*/{1, 2}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, + /*strides=*/{1, 1}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, + /*strides=*/{1, 1}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, + /*strides=*/{2, 99}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + 
{/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, + /*strides=*/{5, 4}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, + /*strides=*/{3, 3}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, + /*strides=*/{4, 5}, /*layout=*/{1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + // Regression test for a bug that appeared in Inception (b/34784899). + {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + // Regression test for a bug that appeared in Inception (b/34784899). + {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, + /*strides=*/{2, 2}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +string R2ReduceWindowTestDataToString( + const ::testing::TestParamInfo& data) { + string str = tensorflow::strings::StrCat( + "base_bounds_", + tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "__window_bounds_", + tensorflow::str_util::Join(data.param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // + "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", data.param.layout[0], "_", data.param.layout[1], // + "__reducer_", data.param.reducer == kAdd ? 
"add" : "max"); + return str; +} + +class R2ReduceWindowTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(R2ReduceWindowTest, Add) { + ComputationBuilder b(client_, TestName()); + const auto& param = GetParam(); + CHECK(param.reducer == kAdd); + + const float kInitValue = 0.0f; + Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, + client_->TransferToServer(*input_literal)); + b.ReduceWindow(/*operand=*/ + b.Parameter(0, input_literal->shape(), "p0"), + /*init_value=*/b.ConstantR0(kInitValue), + /*computation=*/CreateScalarAddComputation(F32, &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow2DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareR2(&b, *expected, {input_arg.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::ValuesIn(kR2TestCases), + R2ReduceWindowTestDataToString); + +struct R1ReduceWindowTestData { + int64 base_bounds[1]; + int64 window_bounds[1]; + int64 strides[1]; + Padding padding; + Reducer reducer; +} kR1TestCases[] = { + {/*base_bounds=*/{1}, /*window_bounds=*/{1}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{3}, /*window_bounds=*/{3}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{3}, /*window_bounds=*/{2}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{1}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + + 
{/*base_bounds=*/{16}, /*window_bounds=*/{4}, + /*strides=*/{4}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + + {/*base_bounds=*/{16}, /*window_bounds=*/{4}, + /*strides=*/{3}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + /*strides=*/{27}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*strides=*/{64}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*strides=*/{56}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{3}, /*window_bounds=*/{2}, + /*strides=*/{1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{3}, + /*strides=*/{2}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{16}, /*window_bounds=*/{4}, + /*strides=*/{3}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, +}; + +string R1ReduceWindowTestDataToString( + const ::testing::TestParamInfo& data) { + string str = tensorflow::strings::StrCat( + "base_bounds_", + tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "__window_bounds_", + tensorflow::str_util::Join(data.param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // + "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // + "__reducer_", data.param.reducer == kAdd ? 
"add" : "max"); + return str; +} + +class R1ReduceWindowTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(R1ReduceWindowTest, DoIt) { + ComputationBuilder b(client_, TestName()); + const auto& param = GetParam(); + CHECK(param.reducer == kAdd || param.reducer == kMax); + + const float kInitValue = 0.0f; + std::vector input_vector(param.base_bounds[0]); + std::iota(std::begin(input_vector), std::end(input_vector), 0); + std::unique_ptr input_literal = + Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, + client_->TransferToServer(*input_literal)); + + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(F32, &b) + : CreateScalarMaxComputation(F32, &b); + b.ReduceWindow(/*operand=*/ + b.Parameter(0, input_literal->shape(), "p0"), + /*init_value=*/b.ConstantR0(kInitValue), + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto reduce_func = param.reducer == kAdd + ? 
+[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + auto expected = ReferenceUtil::ReduceWindow1DGeneric( + /*operand=*/tensorflow::gtl::ArraySlice(input_vector), + /*init=*/kInitValue, + /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), + {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); +} + +INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::ValuesIn(kR1TestCases), + R1ReduceWindowTestDataToString); } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 7c6700feef846242cc49e573fee01c0101b05335..cb7f54ea01c2f063db1575bd498634f5107a39c5 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -61,7 +60,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { // Run it. 
std::unique_ptr literal = - client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + client_ + ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. @@ -92,15 +92,16 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(*Literal::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(*Literal::CreateR0(3)) .ConsumeValueOrDie(); std::unique_ptr literal = client_ ->ExecuteAndTransfer(replayed, - /*arguments=*/{x_data.get(), y_data.get()}) + /*arguments=*/{x_data.get(), y_data.get()}, + &execution_options_) .ConsumeValueOrDie(); // Expect 5. @@ -141,7 +142,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { // Run it. std::unique_ptr literal = - client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + client_ + ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. @@ -154,7 +156,6 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index c9817bc23d821d95e660b359ce72ae6f4dec6c85..3051562455f48625def2840913314b16e8de2b72 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -63,7 +62,6 @@ TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index ae7d07727b1e2c20d629f2abc5e58036060f0cef..6748d196c1a6305cc6e3ff87191d2c96a45bf0e7 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -71,7 +70,7 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_literal = Literal::CreateR0(1.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -99,7 +98,7 @@ XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + Literal::CreateR2FromArray2D(Array2D(0, 3)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -403,7 +402,7 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { XLA_TEST_F(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = LiteralUtil::CreateR1({83.0f}); + auto input = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. 
std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -435,7 +434,7 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); // clang-format off - auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(Array4D{ + auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -467,7 +466,7 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }); Computation computation = builder.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); std::unique_ptr actual = @@ -475,12 +474,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal::CreateR2FromArray2D(expected_array); LiteralTestUtil::ExpectEqual(*expected, *actual); } XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = LiteralUtil::CreateR2({ + std::unique_ptr input = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -508,7 +507,7 @@ XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = LiteralUtil::CreateR2({ + std::unique_ptr input = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -542,7 +541,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -565,7 +564,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -589,7 +588,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -603,7 +602,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); - auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + auto expected = Literal::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); } @@ -615,7 +614,7 @@ 
XLA_TEST_F(ReshapeTest, NoopReshape) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -626,7 +625,7 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); std::unique_ptr output_literal = @@ -642,7 +641,7 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { } XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = LiteralUtil::CreateR4( + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -655,7 +654,7 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { } XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = LiteralUtil::CreateR4( + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -665,7 +664,7 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = LiteralUtil::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -689,7 +688,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + 
Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -698,9 +697,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. @@ -718,7 +717,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -727,9 +726,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -747,7 +746,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -756,9 +755,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -777,7 +776,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -786,9 +785,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -806,7 +805,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -815,9 +814,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal), - input_literal->shape().layout()); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. @@ -831,7 +830,6 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 5ca9702380f4e37b6ba90459222faf832472bbf7..2f72fc0729a8634456986f294bd26de2c37a5212 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -159,7 +158,6 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 05ce22fc359d5c805840e0f07f645cfb8ffb7786..5b4c05c673339a455c9e58d81c73ede182e0f110 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" @@ -66,8 +65,8 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, LiteralUtil::Get(*actual, {0})); - EXPECT_EQ(24.0, LiteralUtil::Get(*actual, {1})); + EXPECT_EQ(42.0, actual->Get({0})); + EXPECT_EQ(24.0, actual->Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { @@ -96,10 +95,10 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); - EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {0, 1})); - EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {1, 0})); - EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + EXPECT_EQ(42.0f, actual->Get({0, 0})); + EXPECT_EQ(24.0f, actual->Get({0, 1})); + EXPECT_EQ(64.0f, actual->Get({1, 0})); + EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); LiteralTestUtil::ExpectEqual(*round_tripped, *actual); @@ -131,10 +130,10 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); - EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {1, 0})); - EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {0, 1})); - EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + EXPECT_EQ(42.0f, actual->Get({0, 0})); + EXPECT_EQ(24.0f, actual->Get({1, 0})); + EXPECT_EQ(64.0f, actual->Get({0, 1})); + EXPECT_EQ(46.0f, 
actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); LiteralTestUtil::ExpectEqual(*round_tripped, *actual); @@ -146,7 +145,6 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index f0760241cdb4e555f3536d024d278c87376bb4d3..e6a6b7b37a4308f2c00f35ae8d3013a59f6c05e7 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -48,62 +47,61 @@ class RoundTripTransferTest : public ClientLibraryTestBase { }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(*Literal::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(*Literal::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(*Literal::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(*Literal::CreateR1({42.0, 64.0})); } 
TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(*Literal::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(*Literal::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(*Literal::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -111,36 +109,33 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(*Literal::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { - RoundTripTest( - 
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), + Literal::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({}).get(), + Literal::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR0(1.0).get(), + Literal::CreateR1({2, 3}).get()})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(*Literal::CreateR4FromArray4D(array4d)); } } // namespace @@ -149,7 +144,6 @@ TEST_F(RoundTripTransferTest, R4F32_Large) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 47a39ffbbc42dedccc98694a23372cb064da752a..6ebd11584ff21abd05effe094b7ffbd7964c865e 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ 
b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -212,9 +211,9 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + std::unique_ptr a_literal = Literal::CreateR0(2.1f); + std::unique_ptr b_literal = Literal::CreateR0(5.5f); + std::unique_ptr c_literal = Literal::CreateR0(0.5f); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -355,26 +354,25 @@ TEST_F(ScalarComputationsTest, DivU32s) { ComputationDataHandle divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Div(dividend, divisor); - TF_ASSIGN_OR_ASSERT_OK(div_computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); } for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = LiteralUtil::CreateR0(dividend); - auto divisor_literal = LiteralUtil::CreateR0(divisor); - TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, - client_->TransferToServer(*dividend_literal)); - TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + auto dividend_literal = Literal::CreateR0(dividend); + auto divisor_literal = Literal::CreateR0(divisor); + TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, + 
client_->TransferToServer(*dividend_literal)); + TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, + client_->TransferToServer(*divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = - LiteralUtil::CreateR0(dividend / divisor); + auto expected_literal = Literal::CreateR0(dividend / divisor); LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } } @@ -397,26 +395,25 @@ TEST_F(ScalarComputationsTest, RemU32s) { ComputationDataHandle divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Rem(dividend, divisor); - TF_ASSIGN_OR_ASSERT_OK(rem_computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); } for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = LiteralUtil::CreateR0(dividend); - auto divisor_literal = LiteralUtil::CreateR0(divisor); - TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, - client_->TransferToServer(*dividend_literal)); - TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + auto dividend_literal = Literal::CreateR0(dividend); + auto divisor_literal = Literal::CreateR0(divisor); + TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, + client_->TransferToServer(*divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = - LiteralUtil::CreateR0(dividend % divisor); + auto expected_literal = Literal::CreateR0(dividend % divisor); LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } } @@ -428,8 +425,8 @@ TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = builder.Parameter(0, 
ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); - TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal)); + std::unique_ptr literal = Literal::CreateR0(87919); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } @@ -764,7 +761,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { TEST_F(ScalarComputationsTest, SqrtF320) { ComputationBuilder builder(client_, TestName()); - Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); + Literal zero_literal = Literal::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); @@ -782,7 +779,6 @@ TEST_F(ScalarComputationsTest, SqrtF320) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 36110da2478083a45d5d378935278de42d55d221..de89588042ec097180906f49fb5b0c4b1fe16edd 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -381,7 +380,6 @@ XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 5eb4fee8ed28192a238efe2e6c9e1cad49a5f836..6b48116b6e1317eb23624242f1de656c3e7d48ca 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -262,7 +261,6 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc index 25bb915be56560e9a4eb0ebce990f488fe074241..38fc27f200ce823c2385d9456f8754dfccb1525e 100644 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -102,7 +101,6 @@ TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 70345c300cc778d9a52ffb857b8a1df2531e8d30..c77e892665b2254bba16c57382d07e38e50a9be7 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -33,91 +32,45 @@ limitations under the License. 
namespace xla { namespace { -class SliceTest : public ClientLibraryTestBase { - protected: - template - void RunSliceTenToTwo() { - std::vector constant; - constant.reserve(10); - for (int i = 0; i < 10; ++i) { - constant.push_back(static_cast(i)); - } - - ComputationBuilder builder(client_, TestName()); - auto original = builder.ConstantR1(constant); - builder.Slice(original, {2}, {4}, {1}); - - const std::vector expected = {static_cast(2), - static_cast(3)}; - ComputeAndCompareR1(&builder, expected, {}); - } -}; - -XLA_TEST_F(SliceTest, SliceZeroToZeroF32) { - ComputationBuilder builder(client_, TestName()); - auto original = builder.ConstantR1({}); - builder.Slice(original, {0}, {0}, {1}); - - ComputeAndCompareR1(&builder, {}, {}); -} - -XLA_TEST_F(SliceTest, SliceTenToZeroF32) { - ComputationBuilder builder(client_, TestName()); - std::vector constant(10, 0.3); - auto original = builder.ConstantR1(constant); - builder.Slice(original, {7}, {7}, {1}); - - ComputeAndCompareR1(&builder, {}, {}); -} - -TEST_F(SliceTest, SliceTenToTwoF32) { RunSliceTenToTwo(); } - -XLA_TEST_F(SliceTest, SliceTenToTwoF64) { RunSliceTenToTwo(); } - -TEST_F(SliceTest, SliceTenToTwoU32) { RunSliceTenToTwo(); } - -TEST_F(SliceTest, SliceTenToTwoS32) { RunSliceTenToTwo(); } - -XLA_TEST_F(SliceTest, SliceTenToTwoU64) { RunSliceTenToTwo(); } - -XLA_TEST_F(SliceTest, SliceTenToTwoS64) { RunSliceTenToTwo(); } +class SliceTest : public ClientLibraryTestBase {}; -TEST_F(SliceTest, SliceTenToTen) { - const std::vector values = {0.0, 1.0, 2.0, 3.0, 4.0, - 5.0, 6.0, 7.0, 8.0, 9.0}; +TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { + Array3D values(3, 3, 3); + values.FillIota(0); ComputationBuilder builder(client_, TestName()); - auto original = builder.ConstantR1(values); - builder.Slice(original, {0}, {10}, {1}); + auto original = builder.ConstantR3FromArray3D(values); + builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); - ComputeAndCompareR1(&builder, values, {}, ErrorSpec(0.000001)); + 
Array3D expected{ + {{0.0}, {3.0}, {6.0}}, {{9.0}, {12.0}, {15.0}}, {{18.0}, {21.0}, {24.0}}}; + ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.000001)); } -TEST_F(SliceTest, SliceLastFourOf1024) { - std::vector values(1024); - std::iota(values.begin(), values.end(), 0.0); +TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { + Array3D values(3, 3, 3); + values.FillIota(0); ComputationBuilder builder(client_, TestName()); - auto original = builder.ConstantR1(values); - builder.Slice(original, {1024 - 4}, {1024}, {1}); + auto original = builder.ConstantR3FromArray3D(values); + builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); - const std::vector expected = {1020, 1021, 1022, 1023}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.000001)); + Array3D expected{ + {{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}}; + ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.000001)); } -// TODO(b/28491443): Fix wrong result on CPU and GPU. Failed on -// 2016-05-01. 
Also b/28508652 -TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) { - std::vector values(4096); - std::iota(values.begin(), values.end(), 0.0); +TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { + Array3D values(3, 3, 3); + values.FillIota(0); ComputationBuilder builder(client_, TestName()); - auto original = builder.ConstantR1(values); - builder.Slice(original, {7}, {7 + 1024}, {1}); + auto original = builder.ConstantR3FromArray3D(values); + builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); - std::vector expected(1024); - std::iota(values.begin(), values.end(), 7.0); - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.000001)); + Array3D expected{ + {{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}}; + ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.000001)); } XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { @@ -201,14 +154,78 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { Array4D values(2, 2, 24, 256); values.FillRandom(3.14f); - auto expected = - ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}); + auto expected = ReferenceUtil::Slice4D( + values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } +struct R1Spec { + int64 input_dim0; + int64 slice_start; + int64 slice_limit; + int64 slice_stride; +}; + +// Parameterized test that generates R1 values, slices them according +// to the R1Spec, and compares the result with a computed version. 
+class SliceR1Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + protected: + template + void Run(const R1Spec& spec) { + std::vector input(spec.input_dim0); + std::iota(input.begin(), input.end(), NativeT()); + + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR1(input); + builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, + {spec.slice_stride}); + + std::vector expected; + for (int i = spec.slice_start; i < spec.slice_limit; + i += spec.slice_stride) { + expected.push_back(i); + } + + ComputeAndCompareR1(&builder, expected, {}); + } +}; + +XLA_TEST_P(SliceR1Test, DoIt) { + Run(GetParam()); + Run(GetParam()); + Run(GetParam()); + Run(GetParam()); + Run(GetParam()); + Run(GetParam()); +} + +INSTANTIATE_TEST_CASE_P( // + SliceR1TestInstantiation, // + SliceR1Test, // + ::testing::Values( // + R1Spec{10, 0, 0, 1}, // + R1Spec{10, 7, 7, 1}, // + R1Spec{10, 2, 4, 1}, // + R1Spec{10, 2, 4, 1}, // + R1Spec{10, 2, 4, 1}, // + R1Spec{10, 2, 4, 1}, // + R1Spec{10, 2, 4, 1}, // + R1Spec{10, 2, 4, 1}, // + R1Spec{10, 0, 10, 1}, // + R1Spec{1024, 1024 - 4, 1024, 1}, // + R1Spec{4096, 7, 7 + 1024, 1}, // + R1Spec{10, 0, 10, 2}, // + R1Spec{10, 0, 10, 3}, // + R1Spec{10, 0, 10, 4}, // + R1Spec{10, 0, 10, 5}, // + R1Spec{10, 0, 10, 10} // + ) // +); + struct R2Spec { int64 input_dim0; int64 input_dim1; @@ -223,17 +240,17 @@ struct R2Spec { class SliceR2Test : public ClientLibraryTestBase, public ::testing::WithParamInterface {}; -TEST_P(SliceR2Test, DoIt) { +XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(input); + auto a = builder.ConstantR2FromArray2DWithLayout(input, spec.layout); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); - std::unique_ptr> expected = - ReferenceUtil::Slice2D(input, 
spec.slice_starts, spec.slice_limits); + std::unique_ptr> expected = ReferenceUtil::Slice2D( + input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2(&builder, *expected, {}); } @@ -258,6 +275,18 @@ INSTANTIATE_TEST_CASE_P( R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}}, LayoutUtil::MakeLayout({1, 0})}, R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, + LayoutUtil::MakeLayout({0, 1})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, + LayoutUtil::MakeLayout({0, 1})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, + LayoutUtil::MakeLayout({0, 1})}, + R2Spec {10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, LayoutUtil::MakeLayout({1, 0})} ) ); @@ -269,7 +298,6 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 6a23df4d3c35a17a56b4ce816f79eaa642831f90..f3a522b05ebae4f1f86d6d7ddbac6e1749d3e286 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -61,7 +61,7 @@ std::unique_ptr CreateR2LiteralWithLayout( auto literal = MakeUnique(); const int64 d0 = values.size(); const int64 d1 = values.begin()->size(); - LiteralUtil::PopulateWithValue(0, {d0, d1}, literal.get()); + literal.get()->PopulateWithValue(0, {d0, d1}); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); 
TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); @@ -70,7 +70,7 @@ std::unique_ptr CreateR2LiteralWithLayout( for (auto inner_list : values) { int64 dim1 = 0; for (auto value : inner_list) { - LiteralUtil::Set(literal.get(), {dim0, dim1}, value); + literal.get()->Set({dim0, dim1}, value); ++dim1; } ++dim0; @@ -88,7 +88,7 @@ std::unique_ptr CreateR3LiteralWithLayout( const int64 d0 = values.size(); const int64 d1 = values.begin()->size(); const int64 d2 = values.begin()->begin()->size(); - LiteralUtil::PopulateWithValue(0, {d0, d1, d2}, literal.get()); + literal.get()->PopulateWithValue(0, {d0, d1, d2}); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); @@ -99,7 +99,7 @@ std::unique_ptr CreateR3LiteralWithLayout( for (auto inner_inner_list : inner_list) { int64 dim2 = 0; for (auto value : inner_inner_list) { - LiteralUtil::Set(literal.get(), {dim0, dim1, dim2}, value); + literal.get()->Set({dim0, dim1, dim2}, value); ++dim2; } ++dim1; diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index e4951c4201060ae01d48f438bc462191de372f0e..07c0f073e86ee204a90b1f138c8c6d90a5c6936a 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -189,7 +188,6 @@ TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 6309e7129735aaaa81a14974ea52bd4cba219dc3..4a1c3fe9629218a0c3c8f5ccacd5500cedf73b61 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -40,6 +39,25 @@ class TupleTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001}; }; +// Tests a tuple-shaped constant. 
+XLA_TEST_F(TupleTest, TupleConstant) { + ComputationBuilder builder(client_, TestName()); + + const float constant_scalar = 7.3f; + std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; + std::initializer_list> constant_matrix = { + {1.1f, 2.2f, 3.5f}, // row 0 + {4.8f, 5.0f, 6.7f}, // row 1 + }; + auto value = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get(), + Literal::CreateR2(constant_matrix).get()}); + + auto result = builder.ConstantLiteral(*value); + ComputeAndCompareTuple(&builder, *value, {}, error_spec_); +} + // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreate) { ComputationBuilder builder(client_, TestName()); @@ -54,10 +72,10 @@ XLA_TEST_F(TupleTest, TupleCreate) { builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto expected = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get(), + Literal::CreateR2(constant_matrix).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -68,9 +86,8 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { auto result = builder.Tuple( {builder.ConstantR0(7.0), builder.ConstantR1({})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), + Literal::CreateR1({}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -78,7 +95,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XLA_TEST_F(TupleTest, EmptyTupleCreate) { ComputationBuilder builder(client_, TestName()); auto result = builder.Tuple({}); - auto expected = LiteralUtil::MakeTuple({}); + auto expected = Literal::MakeTuple({}); 
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -147,12 +164,37 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { builder.ConstantR2(constant_matrix)}); auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), builder.GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + auto expected = + Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), + Literal::CreateR1(constant_vector).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle v1, v2; + + for (bool direction : {false, true}) { + std::unique_ptr v1_data = + CreateR0Parameter(0.0f, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&b, /*data_handle=*/&v1); + std::unique_ptr v2_data = + CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&b, /*data_handle=*/&v2); + auto v1_gt = b.Gt(v1, v2); // false + auto v2_gt = b.Gt(v2, v1); // true + auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} + auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} + auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto expected = + Literal::MakeTuple({Literal::CreateR0(direction).get(), + Literal::CreateR0(!direction).get()}); + + ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + error_spec_); + } +} + // Builds two new tuples from an existing tuple (by means of GetTupleElement), // then adds up the components of the new tuples. 
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { @@ -213,9 +255,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), + Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -259,9 +300,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { auto select = builder.Select(builder.ConstantR0(true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), + Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -340,9 +380,8 @@ XLA_TEST_F(TupleTest, auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), + Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -353,13 +392,13 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto outer_tuple = builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); - auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); - auto expected_s = LiteralUtil::CreateR0(42.0); + auto expected_v1 = Literal::CreateR1({1.0, 2.0}); + auto expected_s = Literal::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); - auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); + Literal::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = Literal::CreateR1({22.0, 44.0}); auto expected = - 
LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -379,14 +418,14 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( + ->TransferToServer(*Literal::MakeTuple({ + Literal::MakeTuple( { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), + Literal::CreateR1({1.0, 2.0, 3.0}).get(), + Literal::CreateR1({4.0, 5.0, 6.0}).get(), }) .get(), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + Literal::CreateR1({7.0, 8.0, 9.0}).get(), })) .ConsumeValueOrDie(); @@ -401,7 +440,6 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 61110d5b4cdaea62aa9844a195ee95698bf1632e..d35d9ecdeb6661ff5d5c8940a0e9dcc609aeb9a2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -165,7 +164,6 @@ TEST_F(UnaryOpTest, SignAbsTestR2) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 26a08953b1534044058a001a8c9a66e6ab6461b0..079dbb06117949c870f89e1a3258e31463aa28ec 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -221,7 +220,6 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index efde45375fdbe8c0abbba0817f9d3062a118ab3c..b2e0c796bde46bac357635a0ab35dc521da7fde4 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -441,7 +440,6 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 5f9177977449561acaec6f480937833ea0de3dd1..8a6c40a0f570d9d979beaa2c1e915004d742675e 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -82,6 +81,70 @@ TEST_F(WhileTest, WhileWithScalarResult) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { + auto result_shape = ShapeUtil::MakeShape(S32, {}); + auto orig_shape = ShapeUtil::MakeShape(S32, {2}); + + // Create a computation for the condition: repeat for 5 iterations. 
+ Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(5), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: add 1 to the result variable. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR0(1); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.Reduce(builder.ConstantR1(2, 1), + builder.ConstantR0(0), + CreateScalarAddComputation(S32, &builder), {0}); + auto result = builder.While(condition, body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 5, {}); +} + +TEST_F(WhileTest, WhileWithPredicateResult) { + auto result_shape = ShapeUtil::MakeShape(PRED, {}); + + // Create a computation for the condition: run until condition is true. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Ne(builder.ConstantR0(true), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: or condition with true. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. 
+ ComputationBuilder builder(client_, TestName()); + auto init = builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true)); + auto result = builder.While(condition, body, init); + + ComputeAndCompareR0(&builder, true, {}); +} + // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results. @@ -240,15 +303,62 @@ TEST_F(WhileTest, WhileWithTupleResult) { VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = LiteralUtil::CreateR0(5); - auto expected_data = LiteralUtil::CreateR1( + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +TEST_F(WhileTest, WhileWithPredicateTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(PRED, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. 
+ // Add 1 to the iteration variable and or the predicate with true + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto pred = builder.GetTupleElement(prev, 1); + auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple({builder.ConstantR0(0), + builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true))}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(5); + auto expected_predicate = Literal::CreateR0(true); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. 
@@ -277,7 +387,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); - TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } Computation condition2; @@ -287,7 +397,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); - TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } // Create a computation for the body. @@ -303,7 +413,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } Computation body2; @@ -316,7 +426,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. 
@@ -356,7 +466,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); - TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } Computation condition2; @@ -366,7 +476,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); - TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } // Create a computation for the body. @@ -382,7 +492,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. @@ -423,7 +533,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); - TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } Computation condition2; @@ -433,7 +543,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); - TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } // Create a computation for the body. 
@@ -449,7 +559,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. @@ -525,11 +635,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = LiteralUtil::CreateR0(5); - auto expected_data = LiteralUtil::CreateR1( + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -587,11 +697,11 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { }; for (int i = 1; i < 4; ++i) { - TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i)); + TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i)); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(65); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); } @@ -743,7 +853,6 @@ BENCHMARK(BM_WhileLoop); int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 7876272467890b56c2cca71f64e66303eb8ac632..4d060895d357493327ec50b38016478c65fef94d 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -104,8 +104,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { auto result = MakeUnique(); const float fill = std::numeric_limits::quiet_NaN(); - LiteralUtil::PopulateWithValue(fill, AsInt64Slice(shape.dimensions()), - result.get()); + result->PopulateWithValue(fill, AsInt64Slice(shape.dimensions())); std::vector pieces; std::vector coordinates; std::vector coordinate_values; @@ -147,7 +146,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { "\"%s\"", shape.dimensions_size(), coordinate_values.size(), line.c_str()); } - LiteralUtil::Set(result.get(), coordinate_values, value); + result->Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index a167d80f73b0273739e22d94be8d90ab00839dc9..23070b663870a2b78b38663e09a32fcb28d9c2dc 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -46,12 +46,12 @@ TEST(TextLiteralReaderTest, ReadsR3File) { TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, LiteralUtil::Get(*literal, {0, 0, 0})); - EXPECT_EQ(43.5, LiteralUtil::Get(*literal, {0, 0, 1})); - EXPECT_EQ(44.5, LiteralUtil::Get(*literal, {0, 0, 2})); - EXPECT_EQ(45.5, LiteralUtil::Get(*literal, {0, 1, 0})); - EXPECT_EQ(46.5, LiteralUtil::Get(*literal, {0, 1, 1})); - EXPECT_EQ(47.5, LiteralUtil::Get(*literal, {0, 1, 2})); + EXPECT_EQ(42.5, literal->Get({0, 0, 0})); + EXPECT_EQ(43.5, literal->Get({0, 0, 1})); + EXPECT_EQ(44.5, literal->Get({0, 0, 2})); + EXPECT_EQ(45.5, 
literal->Get({0, 1, 0})); + EXPECT_EQ(46.5, literal->Get({0, 1, 1})); + EXPECT_EQ(47.5, literal->Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index a5097e41cb3cb3fe1c10e3c21c00c2242087deba..3fee467594d8423c707abf07a0622a738437830a 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -45,9 +45,9 @@ namespace xla { tensorflow::Status status; tensorflow::WritableFile* f_ptr = f.get(); - LiteralUtil::EachCellAsString( - literal, [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + literal.EachCellAsString( + [f_ptr, &status](tensorflow::gtl::ArraySlice indices, + const string& value) { if (!status.ok()) { return; } diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 177ae4ea036af660b7a2be1d4082b30ca8fb9fac..70cf2fb1b8a1b4f2ecfdaeaef3a00ddc974e2652 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -30,7 +30,7 @@ namespace xla { namespace { TEST(TextLiteralWriterTest, WritesFloatLiteral) { - auto literal = LiteralUtil::CreateR2({ + auto literal = Literal::CreateR2({ {3.14, 2.17}, {1.23, 4.56}, }); string path = diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 535e5b605b4f68671c9b6a8af4a12732f88e744e..4bbe0ba0ddd93b59557d3a4c6007ed9d2f8b7c11 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,7 +36,7 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", 
"//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", @@ -187,7 +187,7 @@ cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:session_proto", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 10efa9f3e8d856493b2db23195188da6fba65244..7861c3a9b72e85cba8907c82a9d36d0fe39889c2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -32,7 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,8 +53,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_generate_hlo_graph(".*"); + debug_options.set_xla_hlo_graph_layout(true); ComputationStats stats = - client->GetComputationStats(computation).ConsumeValueOrDie(); + client->GetComputationStats(computation, debug_options) + .ConsumeValueOrDie(); 
fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); } } @@ -63,12 +67,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } tensorflow::port::InitMain(argv[0], &argc, &argv); - xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); - flags->xla_generate_hlo_graph = ".*"; - flags->xla_hlo_graph_layout = true; - tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] xla::tools::RealMain(args); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 850267d3195785a96bf8d2c80fe64fdb8aae0a91..51f90b07c66f7d839f587350726333b9dbe6a9f0 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -30,7 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -52,8 +52,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_generate_hlo_graph(".*"); + debug_options.set_xla_hlo_dump_as_graphdef(true); ComputationStats stats = - client->GetComputationStats(computation).ConsumeValueOrDie(); + client->GetComputationStats(computation, debug_options) + .ConsumeValueOrDie(); fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); } } @@ -62,14 +66,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); - - xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); - flags->xla_generate_hlo_graph = ".*"; + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } - xla::legacy_flags::HloGraphDumperFlags* dumper_flags = - xla::legacy_flags::GetHloGraphDumperFlags(); - dumper_flags->xla_hlo_dump_as_graphdef = true; + tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off 
the binary name, argv[0] diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3a75bf6495415e569aafce1eccc843cc95f9f7fa..6228ca34c0835a7476e45037c9bb6373ee1750dd 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -98,11 +98,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { std::unique_ptr result = result_status.ConsumeValueOrDie(); fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), ShapeUtil::HumanString(result->shape()).c_str(), - LiteralUtil::ToString(*result).c_str()); + result->ToString().c_str()); if (module.has_result()) { fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - LiteralUtil::ToString(Literal(module.result())).c_str()); + Literal(module.result()).ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index b6538f5de07743ef7320343d6b23119e919d114f..b50cb5e28eac14ed99af566939f8bd64e393ff64 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -42,5 +42,5 @@ int main(int argc, char **argv) { &literal_proto)); xla::Literal literal(literal_proto); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 2d983b407c64ab5547d722abcc2c564a7963f730..bbe9902aa17a585c4bad5b732330305dfdd45302 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -40,7 +40,7 @@ int main(int argc, char **argv) { xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); LOG(INFO) << 
"literal: " << literal->ShortDebugString(); - fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(*literal).c_str()); + fprintf(stderr, "%s\n", literal->ToString().c_str()); if (literal->shape().element_type() == xla::F32) { float min = *std::min_element(literal->f32s().begin(), literal->f32s().end()); diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index d467178cb528a93b2c1030fc72d054cc0edf95b6..1ecdb9852d84175dbe30878022519cd62f54747c 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -15,9 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" +#include #include +#include -#include "tensorflow/compiler/xla/legacy_flags/util_flags.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -30,18 +31,12 @@ limitations under the License. namespace xla { namespace { -// Adds a backtrace to the provided status iff the xla_status_add_backtrace flag -// is set. This is useful for quickly tracing status errors observed coming out -// of the service. -Status MaybeAddBacktrace(const Status& prior) { - DCHECK(!prior.ok()); - if (legacy_flags::GetUtilFlags()->xla_status_add_backtrace) { - return Status{prior.code(), - tensorflow::strings::StrCat(prior.error_message(), " :: ", - tensorflow::CurrentStackTrace())}; - } else { - return prior; - } +// Logs the provided status message with a backtrace. +Status WithLogBacktrace(const Status& status) { + CHECK(!status.ok()); + VLOG(1) << status.ToString(); + VLOG(1) << tensorflow::CurrentStackTrace(); + return status; } } // namespace @@ -84,7 +79,7 @@ Status InvalidArgument(const char* format, ...) 
{ va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::InvalidArgument(message)); + return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); } Status Unimplemented(const char* format, ...) { @@ -93,7 +88,7 @@ Status Unimplemented(const char* format, ...) { va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::Unimplemented(message)); + return WithLogBacktrace(tensorflow::errors::Unimplemented(message)); } Status InternalError(const char* format, ...) { @@ -102,7 +97,7 @@ Status InternalError(const char* format, ...) { va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::Internal(message)); + return WithLogBacktrace(tensorflow::errors::Internal(message)); } Status FailedPrecondition(const char* format, ...) { @@ -111,7 +106,7 @@ Status FailedPrecondition(const char* format, ...) { va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::FailedPrecondition(message)); + return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message)); } Status ResourceExhausted(const char* format, ...) { @@ -120,7 +115,7 @@ Status ResourceExhausted(const char* format, ...) { va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::ResourceExhausted(message)); + return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message)); } Status NotFound(const char* format, ...) { @@ -129,7 +124,7 @@ Status NotFound(const char* format, ...) 
{ va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::NotFound(message)); + return WithLogBacktrace(tensorflow::errors::NotFound(message)); } Status Unavailable(const char* format, ...) { @@ -138,7 +133,7 @@ Status Unavailable(const char* format, ...) { va_start(args, format); tensorflow::strings::Appendv(&message, format, args); va_end(args); - return MaybeAddBacktrace(tensorflow::errors::Unavailable(message)); + return WithLogBacktrace(tensorflow::errors::Unavailable(message)); } string Reindent(tensorflow::StringPiece original, diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 42d5c1d15501fb912551a044414e6fa0c83283b8..00151e5da6b7aae79793ccb0f3df49531b417aa9 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -195,16 +195,24 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); // 2. permutation.size() == input.size(). template - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts deleted file mode 100644 index 939300f3878e6c09551c77062a94a92d3cc07000..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {PointMetadata} from './data'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -// tslint:disable-next-line -export let MetadataCardPolymer = PolymerElement({ - is: 'vz-projector-metadata-card', - properties: { - hasMetadata: {type: Boolean, value: false}, - metadata: {type: Array}, - label: String - } -}); - -export class MetadataCard extends MetadataCardPolymer { - hasMetadata: boolean; - metadata: Array<{key: string, value: string}>; - label: string; - - private labelOption: string; - private pointMetadata: PointMetadata; - - private expandLessButton: HTMLButtonElement; - private expandMoreButton: HTMLButtonElement; - - ready() { - this.expandLessButton = - this.querySelector('#expand-less') as HTMLButtonElement; - this.expandMoreButton = - this.querySelector('#expand-more') as HTMLButtonElement; - } - /** Handles a click on the expand more icon. */ - _expandMore() { - (this.$$('#metadata-container') as any).toggle(); - - this.expandMoreButton.style.display = 'none'; - this.expandLessButton.style.display = ''; - } - - /** Handles a click on the expand less icon. 
*/ - _expandLess() { - (this.$$('#metadata-container') as any).toggle(); - this.expandMoreButton.style.display = ''; - this.expandLessButton.style.display = 'none'; - } - - updateMetadata(pointMetadata?: PointMetadata) { - this.pointMetadata = pointMetadata; - this.hasMetadata = (pointMetadata != null); - - if (pointMetadata) { - let metadata = []; - for (let metadataKey in pointMetadata) { - if (!pointMetadata.hasOwnProperty(metadataKey)) { - continue; - } - metadata.push({key: metadataKey, value: pointMetadata[metadataKey]}); - } - - this.metadata = metadata; - this.label = '' + this.pointMetadata[this.labelOption]; - } - } - - setLabelOption(labelOption: string) { - this.labelOption = labelOption; - if (this.pointMetadata) { - this.label = '' + this.pointMetadata[this.labelOption]; - } - } -} - -document.registerElement(MetadataCard.prototype.is, MetadataCard); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html deleted file mode 100644 index b82f3f520b5e62bb381f1a9c6ebd10c4a04d13cf..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html +++ /dev/null @@ -1,316 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts deleted file mode 100644 index 377c6c11ad5d19343682540bdadc3319b5d0ee3c..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ /dev/null @@ -1,589 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import * as data from './data'; -import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data'; -import * as util from './util'; -import * as vector from './vector'; -import {Vector} from './vector'; -import {Projector} from './vz-projector'; -import {ProjectorInput} from './vz-projector-input'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -const NUM_PCA_COMPONENTS = 10; - -// tslint:disable-next-line -export let ProjectionsPanelPolymer = PolymerElement({ - is: 'vz-projector-projections-panel', - properties: { - pcaIs3d: - {type: Boolean, value: true, observer: '_pcaDimensionToggleObserver'}, - tSNEis3d: - {type: Boolean, value: true, observer: '_tsneDimensionToggleObserver'}, - // PCA projection. - pcaComponents: Array, - pcaX: {type: Number, value: 0, observer: 'showPCAIfEnabled'}, - pcaY: {type: Number, value: 1, observer: 'showPCAIfEnabled'}, - pcaZ: {type: Number, value: 2, observer: 'showPCAIfEnabled'}, - // Custom projection. - customSelectedSearchByMetadataOption: { - type: String, - observer: '_customSelectedSearchByMetadataOptionChanged' - }, - } -}); - -type InputControlName = 'xLeft'|'xRight'|'yUp'|'yDown'; - -type CentroidResult = { - centroid?: Vector; numMatches?: number; -}; - -type Centroids = { - [key: string]: Vector; xLeft: Vector; xRight: Vector; yUp: Vector; - yDown: Vector; -}; - -/** - * A polymer component which handles the projection tabs in the projector. 
- */ -export class ProjectionsPanel extends ProjectionsPanelPolymer { - private projector: Projector; - private pcaComponents: - Array<{id: number, componentNumber: number, percVariance: string}>; - private currentProjection: ProjectionType; - private polymerChangesTriggerReprojection: boolean; - private dataSet: DataSet; - private originalDataSet: DataSet; - private dim: number; - - /** T-SNE perplexity. Roughly how many neighbors each point influences. */ - private perplexity: number; - /** T-SNE learning rate. */ - private learningRate: number; - - private searchByMetadataOptions: string[]; - - /** Centroids for custom projections. */ - private centroidValues: any; - private centroids: Centroids; - /** The centroid across all points. */ - private allCentroid: number[]; - - /** Polymer properties. */ - // TODO(nsthorat): Move these to a separate view controller. - public tSNEis3d: boolean; - public pcaIs3d: boolean; - public pcaX: number; - public pcaY: number; - public pcaZ: number; - public customSelectedSearchByMetadataOption: string; - - /** Polymer elements. */ - private runTsneButton: HTMLButtonElement; - private stopTsneButton: HTMLButtonElement; - private perplexitySlider: HTMLInputElement; - private learningRateInput: HTMLInputElement; - private zDropdown: HTMLElement; - private iterationLabel: HTMLElement; - - private customProjectionXLeftInput: ProjectorInput; - private customProjectionXRightInput: ProjectorInput; - private customProjectionYUpInput: ProjectorInput; - private customProjectionYDownInput: ProjectorInput; - - initialize(projector: Projector) { - this.polymerChangesTriggerReprojection = true; - this.projector = projector; - - // Set up TSNE projections. - this.perplexity = 30; - this.learningRate = 10; - - // Setup Custom projections. 
- this.centroidValues = {xLeft: null, xRight: null, yUp: null, yDown: null}; - this.clearCentroids(); - - this.setupUIControls(); - } - - ready() { - this.zDropdown = this.querySelector('#z-dropdown') as HTMLElement; - this.runTsneButton = this.querySelector('.run-tsne') as HTMLButtonElement; - this.stopTsneButton = this.querySelector('.stop-tsne') as HTMLButtonElement; - this.perplexitySlider = - this.querySelector('#perplexity-slider') as HTMLInputElement; - this.learningRateInput = - this.querySelector('#learning-rate-slider') as HTMLInputElement; - this.iterationLabel = this.querySelector('.run-tsne-iter') as HTMLElement; - } - - disablePolymerChangesTriggerReprojection() { - this.polymerChangesTriggerReprojection = false; - } - - enablePolymerChangesTriggerReprojection() { - this.polymerChangesTriggerReprojection = true; - } - - private updateTSNEPerplexityFromSliderChange() { - if (this.perplexitySlider) { - this.perplexity = +this.perplexitySlider.value; - } - (this.querySelector('.tsne-perplexity span') as HTMLSpanElement).innerText = - '' + this.perplexity; - } - - private updateTSNELearningRateFromUIChange() { - if (this.learningRateInput) { - this.learningRate = Math.pow(10, +this.learningRateInput.value); - } - (this.querySelector('.tsne-learning-rate span') as HTMLSpanElement) - .innerText = '' + this.learningRate; - } - - private setupUIControls() { - { - const self = this; - const inkTabs = this.querySelectorAll('.ink-tab'); - for (let i = 0; i < inkTabs.length; i++) { - inkTabs[i].addEventListener('click', function() { - let id = this.getAttribute('data-tab'); - self.showTab(id); - }); - } - } - - this.runTsneButton.addEventListener('click', () => this.runTSNE()); - this.stopTsneButton.addEventListener( - 'click', () => this.dataSet.stopTSNE()); - - this.perplexitySlider.value = this.perplexity.toString(); - this.perplexitySlider.addEventListener( - 'change', () => this.updateTSNEPerplexityFromSliderChange()); - 
this.updateTSNEPerplexityFromSliderChange(); - - this.learningRateInput.addEventListener( - 'change', () => this.updateTSNELearningRateFromUIChange()); - this.updateTSNELearningRateFromUIChange(); - - this.setupCustomProjectionInputFields(); - // TODO: figure out why `--paper-input-container-input` css mixin didn't - // work. - const inputs = - this.querySelectorAll('paper-dropdown-menu paper-input input'); - for (let i = 0; i < inputs.length; i++) { - (inputs[i] as HTMLElement).style.fontSize = '14px'; - } - } - - restoreUIFromBookmark(bookmark: State) { - this.disablePolymerChangesTriggerReprojection(); - - // PCA - this.pcaX = bookmark.pcaComponentDimensions[0]; - this.pcaY = bookmark.pcaComponentDimensions[1]; - if (bookmark.pcaComponentDimensions.length === 3) { - this.pcaZ = bookmark.pcaComponentDimensions[2]; - } - this.pcaIs3d = (bookmark.pcaComponentDimensions.length === 3); - - // t-SNE - if (this.perplexitySlider) { - this.perplexitySlider.value = bookmark.tSNEPerplexity.toString(); - } - if (this.learningRateInput) { - this.learningRateInput.value = bookmark.tSNELearningRate.toString(); - } - this.tSNEis3d = bookmark.tSNEis3d; - - // custom - this.customSelectedSearchByMetadataOption = - bookmark.customSelectedSearchByMetadataOption; - if (this.customProjectionXLeftInput) { - this.customProjectionXLeftInput.set( - bookmark.customXLeftText, bookmark.customXLeftRegex); - } - if (this.customProjectionXRightInput) { - this.customProjectionXRightInput.set( - bookmark.customXRightText, bookmark.customXRightRegex); - } - if (this.customProjectionYUpInput) { - this.customProjectionYUpInput.set( - bookmark.customYUpText, bookmark.customYUpRegex); - } - if (this.customProjectionYDownInput) { - this.customProjectionYDownInput.set( - bookmark.customYDownText, bookmark.customYDownRegex); - } - this.computeAllCentroids(); - - this.setZDropdownEnabled(this.pcaIs3d); - this.updateTSNEPerplexityFromSliderChange(); - this.updateTSNELearningRateFromUIChange(); - if 
(this.iterationLabel) { - this.iterationLabel.innerText = bookmark.tSNEIteration.toString(); - } - if (bookmark.selectedProjection != null) { - this.showTab(bookmark.selectedProjection); - } - this.enablePolymerChangesTriggerReprojection(); - } - - populateBookmarkFromUI(bookmark: State) { - this.disablePolymerChangesTriggerReprojection(); - - // PCA - bookmark.pcaComponentDimensions = [this.pcaX, this.pcaY]; - if (this.pcaIs3d) { - bookmark.pcaComponentDimensions.push(this.pcaZ); - } - - // t-SNE - if (this.perplexitySlider != null) { - bookmark.tSNEPerplexity = +this.perplexitySlider.value; - } - if (this.learningRateInput != null) { - bookmark.tSNELearningRate = +this.learningRateInput.value; - } - bookmark.tSNEis3d = this.tSNEis3d; - - // custom - bookmark.customSelectedSearchByMetadataOption = - this.customSelectedSearchByMetadataOption; - if (this.customProjectionXLeftInput != null) { - bookmark.customXLeftText = this.customProjectionXLeftInput.getValue(); - bookmark.customXLeftRegex = - this.customProjectionXLeftInput.getInRegexMode(); - } - if (this.customProjectionXRightInput != null) { - bookmark.customXRightText = this.customProjectionXRightInput.getValue(); - bookmark.customXRightRegex = - this.customProjectionXRightInput.getInRegexMode(); - } - if (this.customProjectionYUpInput != null) { - bookmark.customYUpText = this.customProjectionYUpInput.getValue(); - bookmark.customYUpRegex = this.customProjectionYUpInput.getInRegexMode(); - } - if (this.customProjectionYDownInput != null) { - bookmark.customYDownText = this.customProjectionYDownInput.getValue(); - bookmark.customYDownRegex = - this.customProjectionYDownInput.getInRegexMode(); - } - - this.enablePolymerChangesTriggerReprojection(); - } - - // This method is marked as public as it is used as the view method that - // abstracts DOM manipulation so we can stub it in a test. - // TODO(nsthorat): Move this to its own class as the glue between this class - // and the DOM. 
- setZDropdownEnabled(enabled: boolean) { - if (this.zDropdown) { - if (enabled) { - this.zDropdown.removeAttribute('disabled'); - } else { - this.zDropdown.setAttribute('disabled', 'true'); - } - } - } - - dataSetUpdated(dataSet: DataSet, originalDataSet: DataSet, dim: number) { - this.dataSet = dataSet; - this.originalDataSet = originalDataSet; - this.dim = dim; - const pointCount = (dataSet == null) ? 0 : dataSet.points.length; - const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4)); - this.perplexitySlider.value = perplexity.toString(); - this.updateTSNEPerplexityFromSliderChange(); - this.clearCentroids(); - - (this.querySelector('#tsne-sampling') as HTMLElement).style.display = - pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'; - const wasSampled = - (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM || - dataSet.dim[1] > data.PCA_SAMPLE_DIM); - (this.querySelector('#pca-sampling') as HTMLElement).style.display = - wasSampled ? null : 'none'; - this.showTab('pca'); - } - - _pcaDimensionToggleObserver() { - this.setZDropdownEnabled(this.pcaIs3d); - this.beginProjection(this.currentProjection); - } - - _tsneDimensionToggleObserver() { - this.beginProjection(this.currentProjection); - } - - metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { - // Project by options for custom projections. - let searchByMetadataIndex = -1; - this.searchByMetadataOptions = spriteAndMetadata.stats.map((stats, i) => { - // Make the default label by the first non-numeric column. 
- if (!stats.isNumeric && searchByMetadataIndex === -1) { - searchByMetadataIndex = i; - } - return stats.name; - }); - this.customSelectedSearchByMetadataOption = - this.searchByMetadataOptions[Math.max(0, searchByMetadataIndex)]; - } - - public showTab(id: ProjectionType) { - this.currentProjection = id; - - const tab = - this.querySelector('.ink-tab[data-tab="' + id + '"]') as HTMLElement; - const allTabs = this.querySelectorAll('.ink-tab'); - for (let i = 0; i < allTabs.length; i++) { - util.classed(allTabs[i] as HTMLElement, 'active', false); - } - - util.classed(tab, 'active', true); - - const allTabContent = this.querySelectorAll('.ink-panel-content'); - for (let i = 0; i < allTabContent.length; i++) { - util.classed(allTabContent[i] as HTMLElement, 'active', false); - } - - util.classed( - this.querySelector('.ink-panel-content[data-panel="' + id + '"]') as - HTMLElement, - 'active', true); - - // guard for unit tests, where polymer isn't attached and $ doesn't exist. - if (this.$ != null) { - const main = this.$['main']; - // In order for the projections panel to animate its height, we need to - // set it explicitly. - requestAnimationFrame(() => { - this.style.height = main.clientHeight + 'px'; - }); - } - - this.beginProjection(id); - } - - private beginProjection(projection: ProjectionType) { - if (this.polymerChangesTriggerReprojection === false) { - return; - } - if (projection === 'pca') { - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - this.showPCA(); - } else if (projection === 'tsne') { - this.showTSNE(); - } else if (projection === 'custom') { - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - this.computeAllCentroids(); - this.reprojectCustom(); - } - } - - private showTSNE() { - const dataSet = this.dataSet; - if (dataSet == null) { - return; - } - const accessors = - data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 2 : null]); - const dimensionality = this.tSNEis3d ? 
3 : 2; - const projection = - new Projection('tsne', accessors, dimensionality, dataSet); - this.projector.setProjection(projection); - - if (!this.dataSet.hasTSNERun) { - this.runTSNE(); - } else { - this.projector.notifyProjectionPositionsUpdated(); - } - } - - private runTSNE() { - this.runTsneButton.disabled = true; - this.stopTsneButton.disabled = null; - this.dataSet.projectTSNE( - this.perplexity, this.learningRate, this.tSNEis3d ? 3 : 2, - (iteration: number) => { - if (iteration != null) { - this.iterationLabel.innerText = '' + iteration; - this.projector.notifyProjectionPositionsUpdated(); - } else { - this.runTsneButton.disabled = null; - this.stopTsneButton.disabled = true; - } - }); - } - - // tslint:disable-next-line:no-unused-variable - private showPCAIfEnabled() { - if (this.polymerChangesTriggerReprojection) { - this.showPCA(); - } - } - - private updateTotalVarianceMessage() { - let variances = this.dataSet.fracVariancesExplained; - let totalVariance = variances[this.pcaX] + variances[this.pcaY]; - let msg = 'Total variance described: '; - if (this.pcaIs3d) { - totalVariance += variances[this.pcaZ]; - } - msg += (totalVariance * 100).toFixed(1) + '%.'; - (this.querySelector('#total-variance') as HTMLElement).innerHTML = msg; - } - - private showPCA() { - if (this.dataSet == null) { - return; - } - this.dataSet.projectPCA().then(() => { - // Polymer properties are 1-based. - const accessors = data.getProjectionComponents( - 'pca', [this.pcaX, this.pcaY, this.pcaZ]); - - const dimensionality = this.pcaIs3d ? 
3 : 2; - const projection = - new Projection('pca', accessors, dimensionality, this.dataSet); - this.projector.setProjection(projection); - let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); - this.updateTotalVarianceMessage(); - this.pcaComponents = util.range(numComponents).map(i => { - let fracVariance = this.dataSet.fracVariancesExplained[i]; - return { - id: i, - componentNumber: i + 1, - percVariance: (fracVariance * 100).toFixed(1) - }; - }); - }); - } - - private reprojectCustom() { - if (this.centroids == null || this.centroids.xLeft == null || - this.centroids.xRight == null || this.centroids.yUp == null || - this.centroids.yDown == null) { - return; - } - const xDir = vector.sub(this.centroids.xRight, this.centroids.xLeft); - this.dataSet.projectLinear(xDir, 'linear-x'); - - const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown); - this.dataSet.projectLinear(yDir, 'linear-y'); - - const accessors = data.getProjectionComponents('custom', ['x', 'y']); - const projection = new Projection('custom', accessors, 2, this.dataSet); - this.projector.setProjection(projection); - } - - clearCentroids(): void { - this.centroids = {xLeft: null, xRight: null, yUp: null, yDown: null}; - this.allCentroid = null; - } - - _customSelectedSearchByMetadataOptionChanged(newVal: string, oldVal: string) { - if (this.polymerChangesTriggerReprojection === false) { - return; - } - if (this.currentProjection === 'custom') { - this.computeAllCentroids(); - this.reprojectCustom(); - } - } - - private setupCustomProjectionInputFields() { - this.customProjectionXLeftInput = - this.setupCustomProjectionInputField('xLeft'); - this.customProjectionXRightInput = - this.setupCustomProjectionInputField('xRight'); - this.customProjectionYUpInput = this.setupCustomProjectionInputField('yUp'); - this.customProjectionYDownInput = - this.setupCustomProjectionInputField('yDown'); - } - - private computeAllCentroids() { - this.computeCentroid('xLeft'); - 
this.computeCentroid('xRight'); - this.computeCentroid('yUp'); - this.computeCentroid('yDown'); - } - - private computeCentroid(name: InputControlName) { - const input = this.querySelector('#' + name) as ProjectorInput; - if (input == null) { - return; - } - const value = input.getValue(); - if (value == null) { - return; - } - let inRegexMode = input.getInRegexMode(); - let result = this.getCentroid(value, inRegexMode); - if (result.numMatches === 0) { - input.message = '0 matches. Using a random vector.'; - result.centroid = vector.rn(this.dim); - } else { - input.message = `${result.numMatches} matches.`; - } - this.centroids[name] = result.centroid; - this.centroidValues[name] = value; - } - - private setupCustomProjectionInputField(name: InputControlName): - ProjectorInput { - let input = this.querySelector('#' + name) as ProjectorInput; - input.registerInputChangedListener((input, inRegexMode) => { - if (this.polymerChangesTriggerReprojection) { - this.computeCentroid(name); - this.reprojectCustom(); - } - }); - return input; - } - - private getCentroid(pattern: string, inRegexMode: boolean): CentroidResult { - if (pattern == null || pattern === '') { - return {numMatches: 0}; - } - // Search by the original dataset since we often want to filter and project - // only the nearest neighbors of A onto B-C where B and C are not nearest - // neighbors of A. 
- let accessor = (i: number) => this.originalDataSet.points[i].vector; - let r = this.originalDataSet.query( - pattern, inRegexMode, this.customSelectedSearchByMetadataOption); - return {centroid: vector.centroid(r, accessor), numMatches: r.length}; - } - - getPcaSampledDimText() { - return data.PCA_SAMPLE_DIM.toLocaleString(); - } - - getPcaSampleSizeText() { - return data.PCA_SAMPLE_SIZE.toLocaleString(); - } - - getTsneSampleSizeText() { - return data.TSNE_SAMPLE_SIZE.toLocaleString(); - } -} - -document.registerElement(ProjectionsPanel.prototype.is, ProjectionsPanel); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts deleted file mode 100644 index 44062062a364b742e2de6467614e508d4e89d37a..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -export type Spec = { - is: string; properties?: { - [key: string]: - (Function | - { - type: Function, value?: any; - readonly?: boolean; - notify?: boolean; - observer?: string; - }) - }; - observers?: string[]; -}; - -export function PolymerElement(spec: Spec) { - return Polymer.Class(spec as any) as{new (): PolymerHTMLElement}; -} - -export interface PolymerHTMLElement extends HTMLElement, polymer.Base {} diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.html b/tensorflow/tensorboard/components/vz_projector/vz-projector.html deleted file mode 100644 index 438ea9f4e978fa608eb0cabde35e9adf6f7e87fe..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.html +++ /dev/null @@ -1,346 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts deleted file mode 100644 index bf98a4d478599f7b859e893e7a17567f22fd5114..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ /dev/null @@ -1,570 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {AnalyticsLogger} from './analyticsLogger'; -import * as data from './data'; -import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; -import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider'; -import {DemoDataProvider} from './data-provider-demo'; -import {ProtoDataProvider} from './data-provider-proto'; -import {ServerDataProvider} from './data-provider-server'; -import * as knn from './knn'; -import * as logging from './logging'; -import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; -import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter'; -import {MouseMode} from './scatterPlot'; -import * as util from './util'; -import {BookmarkPanel} from './vz-projector-bookmark-panel'; -import {DataPanel} from './vz-projector-data-panel'; -import {InspectorPanel} from './vz-projector-inspector-panel'; -import {MetadataCard} from './vz-projector-metadata-card'; -import {ProjectionsPanel} from './vz-projector-projections-panel'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -/** - * The minimum number of dimensions the data should have to automatically - * decide to normalize the data. 
- */ -const THRESHOLD_DIM_NORMALIZE = 50; -const POINT_COLOR_MISSING = 'black'; - -export let ProjectorPolymer = PolymerElement({ - is: 'vz-projector', - properties: { - routePrefix: String, - dataProto: {type: String, observer: '_dataProtoChanged'}, - servingMode: String, - projectorConfigJsonPath: String, - pageViewLogging: Boolean, - eventLogging: Boolean - } -}); - -const INDEX_METADATA_FIELD = '__index__'; - -export class Projector extends ProjectorPolymer implements - ProjectorEventContext { - // The working subset of the data source's original data set. - dataSet: DataSet; - servingMode: ServingMode; - // The path to the projector config JSON file for demo mode. - projectorConfigJsonPath: string; - - private selectionChangedListeners: SelectionChangedListener[]; - private hoverListeners: HoverListener[]; - private projectionChangedListeners: ProjectionChangedListener[]; - private distanceMetricChangedListeners: DistanceMetricChangedListener[]; - - private originalDataSet: DataSet; - private dataSetBeforeFilter: DataSet; - private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; - private dim: number; - - private dataSetFilterIndices: number[]; - private selectedPointIndices: number[]; - private neighborsOfFirstPoint: knn.NearestEntry[]; - private hoverPointIndex: number; - - private dataProvider: DataProvider; - private inspectorPanel: InspectorPanel; - - private selectedColorOption: ColorOption; - private selectedLabelOption: string; - private routePrefix: string; - private normalizeData: boolean; - private projection: Projection; - - /** Polymer component panels */ - private dataPanel: DataPanel; - private bookmarkPanel: BookmarkPanel; - private projectionsPanel: ProjectionsPanel; - private metadataCard: MetadataCard; - - private statusBar: HTMLDivElement; - private analyticsLogger: AnalyticsLogger; - private eventLogging: boolean; - private pageViewLogging: boolean; - - ready() { - logging.setDomContainer(this); - - this.analyticsLogger = - new 
AnalyticsLogger(this.pageViewLogging, this.eventLogging); - this.analyticsLogger.logPageView('embeddings'); - - if (!util.hasWebGLSupport()) { - this.analyticsLogger.logWebGLDisabled(); - logging.setErrorMessage( - 'Your browser or device does not have WebGL enabled. Please enable ' + - 'hardware acceleration, or use a browser that supports WebGL.'); - return; - } - - this.selectionChangedListeners = []; - this.hoverListeners = []; - this.projectionChangedListeners = []; - this.distanceMetricChangedListeners = []; - this.selectedPointIndices = []; - this.neighborsOfFirstPoint = []; - - this.dataPanel = this.$['data-panel'] as DataPanel; - this.inspectorPanel = this.$['inspector-panel'] as InspectorPanel; - this.inspectorPanel.initialize(this, this as ProjectorEventContext); - this.projectionsPanel = this.$['projections-panel'] as ProjectionsPanel; - this.projectionsPanel.initialize(this); - this.bookmarkPanel = this.$['bookmark-panel'] as BookmarkPanel; - this.bookmarkPanel.initialize(this, this as ProjectorEventContext); - this.metadataCard = this.$['metadata-card'] as MetadataCard; - this.statusBar = this.querySelector('#status-bar') as HTMLDivElement; - this.scopeSubtree(this.$$('#notification-dialog'), true); - this.setupUIControls(); - this.initializeDataProvider(); - } - - setSelectedLabelOption(labelOption: string) { - this.selectedLabelOption = labelOption; - this.metadataCard.setLabelOption(this.selectedLabelOption); - this.projectorScatterPlotAdapter.setLabelPointAccessor(labelOption); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.render(); - } - - setSelectedColorOption(colorOption: ColorOption) { - this.selectedColorOption = colorOption; - this.projectorScatterPlotAdapter.setLegendPointColorer( - this.getLegendPointColorer(colorOption)); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.render(); - } - - setNormalizeData(normalizeData: boolean) { - 
this.normalizeData = normalizeData; - this.setCurrentDataSet(this.originalDataSet.getSubset()); - } - - updateDataSet( - ds: DataSet, spriteAndMetadata?: SpriteAndMetadataInfo, - metadataFile?: string) { - this.dataSetFilterIndices = null; - this.originalDataSet = ds; - if (ds != null) { - this.normalizeData = - this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; - spriteAndMetadata = spriteAndMetadata || {}; - if (spriteAndMetadata.pointsInfo == null) { - let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points); - spriteAndMetadata.pointsInfo = pointsInfo; - spriteAndMetadata.stats = stats; - } - let metadataMergeSucceeded = ds.mergeMetadata(spriteAndMetadata); - if (!metadataMergeSucceeded) { - return; - } - } - if (this.projectorScatterPlotAdapter != null) { - if (ds == null) { - this.projectorScatterPlotAdapter.setLabelPointAccessor(null); - this.setProjection(null); - } else { - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.resize(); - this.projectorScatterPlotAdapter.render(); - } - } - if (ds != null) { - this.dataPanel.setNormalizeData(this.normalizeData); - this.setCurrentDataSet(ds.getSubset()); - this.projectorScatterPlotAdapter.setLabelPointAccessor( - this.selectedLabelOption); - this.inspectorPanel.datasetChanged(); - - this.inspectorPanel.metadataChanged(spriteAndMetadata); - this.projectionsPanel.metadataChanged(spriteAndMetadata); - this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); - // Set the container to a fixed height, otherwise in Colab the - // height can grow indefinitely. 
- const container = this.querySelector('#container') as HTMLDivElement; - container.style.height = container.clientHeight + 'px'; - } else { - this.setCurrentDataSet(null); - } - } - - setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { - this.bookmarkPanel.setSelectedTensor(run, tensorInfo, this.dataProvider); - } - - /** - * Registers a listener to be called any time the selected point set changes. - */ - registerSelectionChangedListener(listener: SelectionChangedListener) { - this.selectionChangedListeners.push(listener); - } - - filterDataset(pointIndices: number[]) { - const selectionSize = this.selectedPointIndices.length; - if (this.dataSetBeforeFilter == null) { - this.dataSetBeforeFilter = this.dataSet; - } - this.setCurrentDataSet(this.dataSet.getSubset(pointIndices)); - this.dataSetFilterIndices = pointIndices; - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.adjustSelectionAndHover(util.range(selectionSize)); - } - - resetFilterDataset() { - const originalPointIndices = this.selectedPointIndices.map( - filteredIndex => this.dataSet.points[filteredIndex].index); - this.setCurrentDataSet(this.dataSetBeforeFilter); - if (this.projection != null) { - this.projection.dataSet = this.dataSetBeforeFilter; - } - this.dataSetBeforeFilter = null; - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.dataSetFilterIndices = []; - this.adjustSelectionAndHover(originalPointIndices); - } - - /** - * Used by clients to indicate that a selection has occurred. 
- */ - notifySelectionChanged(newSelectedPointIndices: number[]) { - this.selectedPointIndices = newSelectedPointIndices; - let neighbors: knn.NearestEntry[] = []; - - if (newSelectedPointIndices.length === 1) { - neighbors = this.dataSet.findNeighbors( - newSelectedPointIndices[0], this.inspectorPanel.distFunc, - this.inspectorPanel.numNN); - this.metadataCard.updateMetadata( - this.dataSet.points[newSelectedPointIndices[0]].metadata); - } else { - this.metadataCard.updateMetadata(null); - } - - this.selectionChangedListeners.forEach( - l => l(this.selectedPointIndices, neighbors)); - } - - /** - * Registers a listener to be called any time the mouse hovers over a point. - */ - registerHoverListener(listener: HoverListener) { - this.hoverListeners.push(listener); - } - - /** - * Used by clients to indicate that a hover is occurring. - */ - notifyHoverOverPoint(pointIndex: number) { - this.hoverListeners.forEach(l => l(pointIndex)); - } - - registerProjectionChangedListener(listener: ProjectionChangedListener) { - this.projectionChangedListeners.push(listener); - } - - notifyProjectionChanged(projection: Projection) { - this.projectionChangedListeners.forEach(l => l(projection)); - } - - registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) { - this.distanceMetricChangedListeners.push(l); - } - - notifyDistanceMetricChanged(distMetric: DistanceFunction) { - this.distanceMetricChangedListeners.forEach(l => l(distMetric)); - } - - _dataProtoChanged(dataProtoString: string) { - let dataProto = - dataProtoString ? 
JSON.parse(dataProtoString) as DataProto : null; - this.initializeDataProvider(dataProto); - } - - private makeDefaultPointsInfoAndStats(points: DataPoint[]): - [PointMetadata[], ColumnStats[]] { - let pointsInfo: PointMetadata[] = []; - points.forEach(p => { - let pointInfo: PointMetadata = {}; - pointInfo[INDEX_METADATA_FIELD] = p.index; - pointsInfo.push(pointInfo); - }); - let stats: ColumnStats[] = [{ - name: INDEX_METADATA_FIELD, - isNumeric: false, - tooManyUniqueValues: true, - min: 0, - max: pointsInfo.length - 1 - }]; - return [pointsInfo, stats]; - } - - private initializeDataProvider(dataProto?: DataProto) { - if (this.servingMode === 'demo') { - let projectorConfigUrl: string; - - // Only in demo mode do we allow the config being passed via URL. - let urlParams = util.getURLParams(window.location.search); - if ('config' in urlParams) { - projectorConfigUrl = urlParams['config']; - } else { - projectorConfigUrl = this.projectorConfigJsonPath; - } - this.dataProvider = new DemoDataProvider(projectorConfigUrl); - } else if (this.servingMode === 'server') { - if (!this.routePrefix) { - throw 'route-prefix is a required parameter'; - } - this.dataProvider = new ServerDataProvider(this.routePrefix); - } else if (this.servingMode === 'proto' && dataProto != null) { - this.dataProvider = new ProtoDataProvider(dataProto); - } - - this.dataPanel.initialize(this, this.dataProvider); - } - - private getLegendPointColorer(colorOption: ColorOption): - (ds: DataSet, index: number) => string { - if ((colorOption == null) || (colorOption.map == null)) { - return null; - } - const colorer = (ds: DataSet, i: number) => { - let value = ds.points[i].metadata[this.selectedColorOption.name]; - if (value == null) { - return POINT_COLOR_MISSING; - } - return colorOption.map(value); - }; - return colorer; - } - - private get3DLabelModeButton(): any { - return this.querySelector('#labels3DMode'); - } - - private get3DLabelMode(): boolean { - const label3DModeButton = 
this.get3DLabelModeButton(); - return (label3DModeButton as any).active; - } - - adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) { - this.notifySelectionChanged(selectedPointIndices); - this.notifyHoverOverPoint(hoverIndex); - this.setMouseMode(MouseMode.CAMERA_AND_CLICK_SELECT); - } - - private setMouseMode(mouseMode: MouseMode) { - let selectModeButton = this.querySelector('#selectMode'); - (selectModeButton as any).active = (mouseMode === MouseMode.AREA_SELECT); - this.projectorScatterPlotAdapter.scatterPlot.setMouseMode(mouseMode); - } - - private setCurrentDataSet(ds: DataSet) { - this.adjustSelectionAndHover([]); - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - if ((ds != null) && this.normalizeData) { - ds.normalize(); - } - this.dim = (ds == null) ? 0 : ds.dim[1]; - (this.querySelector('span.numDataPoints') as HTMLSpanElement).innerText = - (ds == null) ? '0' : '' + ds.dim[0]; - (this.querySelector('span.dim') as HTMLSpanElement).innerText = - (ds == null) ? '0' : '' + ds.dim[1]; - - this.dataSet = ds; - - this.projectionsPanel.dataSetUpdated( - this.dataSet, this.originalDataSet, this.dim); - - this.projectorScatterPlotAdapter.setDataSet(this.dataSet); - this.projectorScatterPlotAdapter.scatterPlot - .setCameraParametersForNextCameraCreation(null, true); - } - - private setupUIControls() { - // View controls - this.querySelector('#reset-zoom').addEventListener('click', () => { - this.projectorScatterPlotAdapter.scatterPlot.resetZoom(); - this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation(); - }); - - let selectModeButton = this.querySelector('#selectMode'); - selectModeButton.addEventListener('click', (event) => { - this.setMouseMode( - (selectModeButton as any).active ? 
MouseMode.AREA_SELECT : - MouseMode.CAMERA_AND_CLICK_SELECT); - }); - let nightModeButton = this.querySelector('#nightDayMode'); - nightModeButton.addEventListener('click', () => { - this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode( - (nightModeButton as any).active); - }); - - const labels3DModeButton = this.get3DLabelModeButton(); - labels3DModeButton.addEventListener('click', () => { - this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode()); - }); - - window.addEventListener('resize', () => { - const container = this.querySelector('#container') as HTMLDivElement; - const parentHeight = (container.parentNode as HTMLElement).clientHeight; - container.style.height = parentHeight + 'px'; - this.projectorScatterPlotAdapter.resize(); - }); - - { - this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter( - this.getScatterContainer(), this as ProjectorEventContext); - this.projectorScatterPlotAdapter.setLabelPointAccessor( - this.selectedLabelOption); - } - - this.projectorScatterPlotAdapter.scatterPlot.onCameraMove( - (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => - this.bookmarkPanel.clearStateSelection()); - - this.registerHoverListener( - (hoverIndex: number) => this.onHover(hoverIndex)); - - this.registerSelectionChangedListener( - (selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[]) => - this.onSelectionChanged( - selectedPointIndices, neighborsOfFirstPoint)); - } - - private onHover(hoverIndex: number) { - this.hoverPointIndex = hoverIndex; - let hoverText = null; - if (hoverIndex != null) { - const point = this.dataSet.points[hoverIndex]; - if (point.metadata[this.selectedLabelOption]) { - hoverText = point.metadata[this.selectedLabelOption].toString(); - } - } - if (this.selectedPointIndices.length === 0) { - this.statusBar.style.display = hoverText ? 
null : 'none'; - this.statusBar.innerText = hoverText; - } - } - - private getScatterContainer(): HTMLDivElement { - return this.querySelector('#scatter') as HTMLDivElement; - } - - private onSelectionChanged( - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[]) { - this.selectedPointIndices = selectedPointIndices; - this.neighborsOfFirstPoint = neighborsOfFirstPoint; - let totalNumPoints = - this.selectedPointIndices.length + neighborsOfFirstPoint.length; - this.statusBar.innerText = `Selected ${totalNumPoints} points`; - this.statusBar.style.display = totalNumPoints > 0 ? null : 'none'; - } - - setProjection(projection: Projection) { - this.projection = projection; - if (projection != null) { - this.analyticsLogger.logProjectionChanged(projection.projectionType); - } - this.notifyProjectionChanged(projection); - } - - notifyProjectionPositionsUpdated() { - this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated(); - } - - /** - * Gets the current view of the embedding and saves it as a State object. - */ - getCurrentState(): State { - const state = new State(); - - // Save the individual datapoint projections. 
- state.projections = []; - for (let i = 0; i < this.dataSet.points.length; i++) { - const point = this.dataSet.points[i]; - const projections: {[key: string]: number} = {}; - const keys = Object.keys(point.projections); - for (let j = 0; j < keys.length; ++j) { - projections[keys[j]] = point.projections[keys[j]]; - } - state.projections.push(projections); - } - state.selectedProjection = this.projection.projectionType; - state.dataSetDimensions = this.dataSet.dim; - state.tSNEIteration = this.dataSet.tSNEIteration; - state.selectedPoints = this.selectedPointIndices; - state.filteredPoints = this.dataSetFilterIndices; - this.projectorScatterPlotAdapter.populateBookmarkFromUI(state); - state.selectedColorOptionName = this.dataPanel.selectedColorOptionName; - state.forceCategoricalColoring = this.dataPanel.forceCategoricalColoring; - state.selectedLabelOption = this.selectedLabelOption; - this.projectionsPanel.populateBookmarkFromUI(state); - return state; - } - - /** Loads a State object into the world. 
*/ - loadState(state: State) { - this.setProjection(null); - { - this.projectionsPanel.disablePolymerChangesTriggerReprojection(); - if (this.dataSetBeforeFilter != null) { - this.resetFilterDataset(); - } - if (state.filteredPoints != null) { - this.filterDataset(state.filteredPoints); - } - this.projectionsPanel.enablePolymerChangesTriggerReprojection(); - } - for (let i = 0; i < state.projections.length; i++) { - const point = this.dataSet.points[i]; - const projection = state.projections[i]; - const keys = Object.keys(projection); - for (let j = 0; j < keys.length; ++j) { - point.projections[keys[j]] = projection[keys[j]]; - } - } - this.dataSet.hasTSNERun = (state.selectedProjection === 'tsne'); - this.dataSet.tSNEIteration = state.tSNEIteration; - this.projectionsPanel.restoreUIFromBookmark(state); - this.inspectorPanel.restoreUIFromBookmark(state); - this.dataPanel.selectedColorOptionName = state.selectedColorOptionName; - this.dataPanel.setForceCategoricalColoring( - !!state.forceCategoricalColoring); - this.selectedLabelOption = state.selectedLabelOption; - this.projectorScatterPlotAdapter.restoreUIFromBookmark(state); - { - const dimensions = stateGetAccessorDimensions(state); - const components = - data.getProjectionComponents(state.selectedProjection, dimensions); - const projection = new Projection( - state.selectedProjection, components, dimensions.length, - this.dataSet); - this.setProjection(projection); - } - this.notifySelectionChanged(state.selectedPoints); - } -} - -document.registerElement(Projector.prototype.is, Projector); diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD deleted file mode 100644 index e06b8ae19790490e73d3ceb552ea03d9f304e68d..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - 
-load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_sorting", - srcs = [ - "sorting.ts", - "vz-sorting.html", - ], - path = "/vz-sorting", - visibility = ["//visibility:public"], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":vz_sorting"], - destdir = "vz-sorting", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/sorting.ts b/tensorflow/tensorboard/components/vz_sorting/sorting.ts deleted file mode 100644 index 061184d24bf30623e05834269b32acf745a56299..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/sorting.ts +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * Compares tag names asciinumerically broken into components. - * - *

This is the comparison function used for sorting most string values in - * TensorBoard. Unlike the standard asciibetical comparator, this function - * knows that 'a10b' > 'a2b'. Fixed point and engineering notation are - * supported. This function also splits the input by slash and underscore to - * perform array comparison. Therefore it knows that 'a/a' < 'a+/a' even - * though '+' < '/' in the ASCII table. - */ -export function compareTagNames(a, b: string): number { - let ai = 0; - let bi = 0; - while (true) { - if (ai === a.length) { - return bi === b.length ? 0 : -1; - } - if (bi === b.length) { - return 1; - } - if (isDigit(a[ai]) && isDigit(b[bi])) { - const ais = ai; - const bis = bi; - ai = consumeNumber(a, ai + 1); - bi = consumeNumber(b, bi + 1); - const an = parseFloat(a.slice(ais, ai)); - const bn = parseFloat(b.slice(bis, bi)); - if (an < bn) { - return -1; - } - if (an > bn) { - return 1; - } - continue; - } - if (isBreak(a[ai])) { - if (!isBreak(b[bi])) { - return -1; - } - } else if (isBreak(b[bi])) { - return 1; - } else if (a[ai] < b[bi]) { - return -1; - } else if (a[ai] > b[bi]) { - return 1; - } - ai++; - bi++; - } -} - -function consumeNumber(s: string, i: number): number { - enum State { NATURAL, REAL, EXPONENT_SIGN, EXPONENT } - let state = State.NATURAL; - for (; i < s.length; i++) { - if (state === State.NATURAL) { - if (s[i] === '.') { - state = State.REAL; - } else if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.REAL) { - if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.EXPONENT_SIGN) { - if (isDigit(s[i]) || s[i] === '+' || s[i] === '-') { - state = State.EXPONENT; - } else { - break; - } - } else if (state === State.EXPONENT) { - if (!isDigit(s[i])) { - break; - } - } - } - return i; -} - -function isDigit(c: string): boolean { - return '0' <= c && c <= 
'9'; -} - -function isBreak(c: string): boolean { - // TODO(jart): Remove underscore when people stop using it like a slash. - return c === '/' || c === '_' || isDigit(c); -} diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD deleted file mode 100644 index 929e80d37282387823ea4a93874a112710269cc1..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "sortingTests.ts", - "tests.html", - ], - path = "/vz-sorting/test", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/vz_sorting", - ], -) - -tensorboard_html_binary( - name = "devserver", - compilation_level = "WHITESPACE_ONLY", - input_path = "/vz-sorting/test/tests.html", - output_path = "/vz-sorting/test/tests.html", - deps = [":test"], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts b/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts deleted file mode 100644 index 510685cb4b5e42ca19e56acef6b1f87347811c99..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {compareTagNames} from '../sorting'; - -describe('compareTagNames', () => { - - const assert = chai.assert; - const sortTagNames = (a) => a.sort(compareTagNames); - - it('is asciibetical', () => { - assert.deepEqual(sortTagNames(['a', 'b']), ['a', 'b']); - assert.deepEqual(sortTagNames(['a', 'B']), ['B', 'a']); - }); - - it('sorts integer portions', () => { - assert.deepEqual(['03', '1'].sort(), ['03', '1']); - assert.deepEqual(sortTagNames(['03', '1']), ['1', '03']); - assert.deepEqual(sortTagNames(['a03', 'a1']), ['a1', 'a03']); - assert.deepEqual(sortTagNames(['a03', 'b1']), ['a03', 'b1']); - assert.deepEqual(sortTagNames(['x0a03', 'x0a1']), ['x0a1', 'x0a03']); - assert.deepEqual(sortTagNames(['a/b/03', 'a/b/1']), ['a/b/1', 'a/b/03']); - }); - - it('sorts fixed point numbers', () => { - assert.deepEqual(sortTagNames(['a0.1', 'a0.01']), ['a0.01', 'a0.1']); - }); - - it('sorts engineering notation', () => { - assert.deepEqual(sortTagNames(['a1e9', 'a9e8']), ['a9e8', 'a1e9']); - assert.deepEqual(sortTagNames(['a1e+9', 'a9e+8']), ['a9e+8', 'a1e+9']); - assert.deepEqual(sortTagNames(['a1e+5', 'a9e-6']), ['a9e-6', 'a1e+5']); - assert.deepEqual(sortTagNames(['a1.0e9', 'a9.0e8']), ['a9.0e8', 'a1.0e9']); - assert.deepEqual( - sortTagNames(['a1.0e+9', 'a9.0e+8']), ['a9.0e+8', 'a1.0e+9']); - }); - - it('is componentized by slash', () => { - assert.deepEqual(['a+/a', 'a/a', 'ab/a'].sort(), ['a+/a', 'a/a', 'ab/a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a/a', 'ab/a']), ['a/a', 'a+/a', 
'ab/a']); - }); - - it('is componentized by underscore', () => { - assert.deepEqual( - sortTagNames(['a+_a', 'a_a', 'ab_a']), ['a_a', 'a+_a', 'ab_a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a_a', 'ab_a']), ['a_a', 'a+/a', 'ab_a']); - }); - - it('is componentized by number boundaries', () => { - assert.deepEqual( - sortTagNames(['a+0a', 'a0a', 'ab0a']), ['a0a', 'a+0a', 'ab0a']); - }); - - it('empty comes first', () => { - assert.deepEqual(sortTagNames(['a', '//', '/', '']), ['', '/', '//', 'a']); - }); - - it('decimal parsed correctly', () => { - assert.deepEqual(sortTagNames(['0.2', '0.03']), ['0.03', '0.2']); - assert.deepEqual(sortTagNames(['0..2', '0..03']), ['0..2', '0..03']); - assert.deepEqual(sortTagNames(['.2', '.03']), ['.2', '.03']); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_sorting/test/tests.html b/tensorflow/tensorboard/components/vz_sorting/test/tests.html deleted file mode 100644 index f92c608cdb125ec7e6d6b538d089f2779732ce6a..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/tests.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html deleted file mode 100644 index 5ff6f311589d2ef1c65dbfb052d255390c36991f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/defs/BUILD b/tensorflow/tensorboard/defs/BUILD deleted file mode 100644 index 92a2af34048deaf6da07a7b14aa42e4cd8202958..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -filegroup( - name = "ts_web_library_default_typings", - srcs = [ - # Ordering probably matters. 
- "@com_microsoft_typescript//:lib.es6.d.ts", - "@io_angular_clutz//:src/resources/closure.lib.d.ts", - "clutz.d.ts", - ], - visibility = ["//visibility:public"], -) diff --git a/tensorflow/tensorboard/defs/clutz.d.ts b/tensorflow/tensorboard/defs/clutz.d.ts deleted file mode 100644 index 47cf307d2619a4a84f631dceb03b393cd04aa0d6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/clutz.d.ts +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// tslint:disable -declare namespace ಠ_ಠ.clutz { - interface IteratorIterable extends Iterator, Iterable {} - interface IIterableResult extends IteratorResult {} -} diff --git a/tensorflow/tensorboard/defs/hacks.bzl b/tensorflow/tensorboard/defs/hacks.bzl deleted file mode 100644 index f1d4be790612ac912dc1b1a2298f8bc8dd99dee6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/hacks.bzl +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(jart): Merge this file into defs.bzl once that file is sync unified. - -def tensorboard_typescript_bundle( - name, - out, - namespace_srcs, - namespace_symbol_aliases={}, - namespace_symbol_aliases_public={}, - **kwargs): - """Rolls TypeScript ES6 modules into one vanilla source file without imports. - - This is a genrule wrapper that concatenates TypeScripts sources inside - namespace blocks while removing ^import lines. Because the sources themselves - are not parsed, the structure of the modules must be passed to this macro as - a Skylark data structure. - - Args: - name: Name of this build rule target. - out: Path of outputted TypeScript source file. - namespace_srcs: Multimap of namespace strings to build file targets. The - ordering of the dictionary and nested lists does not matter when - generating a typings file, but *does* matter when generating a source - file. - namespace_symbol_aliases: Map of namespace strings where each value is a - map of symbol names to fully qualified symbol names. - namespace_symbol_aliases_public: Same as namespace_symbol_aliases but the - symbol will be visible to other namespaces. 
- """ - cmd = ["(", "echo // GENERATED BY TENSORBOARD_TYPESCRIPT_BUNDLE"] - inputs = set() - for namespace, srcs in namespace_srcs.items(): - cmd.append("echo") - if out[-5:] == ".d.ts": - cmd.append("echo 'declare namespace %s {'" % namespace) - elif out[-3:] == ".ts": - cmd.append("echo 'module %s {'" % namespace) - else: - fail("'out' must end with .ts or .d.ts: " + out) - for symbol, canon in namespace_symbol_aliases.get(namespace, {}).items(): - cmd.append("echo 'import %s = %s;'" % (symbol, canon)) - for symbol, canon in namespace_symbol_aliases_public.get(namespace, - {}).items(): - cmd.append("echo 'export import %s = %s;'" % (symbol, canon)) - inputs += srcs - for src in srcs: - cmd.append("for f in $(locations %s); do" % src) - cmd.append(" echo") - cmd.append(" echo /////////////////////////////////////////////////////") - cmd.append(" echo // " + namespace) - cmd.append(" echo // $$f") - cmd.append(" echo /////////////////////////////////////////////////////") - cmd.append(" echo") - cmd.append(" sed 's!^import !// import !' $$f \\") - cmd.append(" | sed 's!^export declare !export !' \\") - cmd.append(" | sed '/^export .* from /d' \\") - cmd.append(" | sed '/^export {.*};$$/d'") - cmd.append("done") - cmd.append("echo '}'") - cmd.append(") >$@") - native.genrule( - name = name, - srcs = list(inputs), - outs = [out], - cmd = "\n".join(cmd), - **kwargs - ) diff --git a/tensorflow/tensorboard/defs/protos.bzl b/tensorflow/tensorboard/defs/protos.bzl deleted file mode 100644 index 6d1982e098d9c549a3f6387035c6877d0b798ab7..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/protos.bzl +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@protobuf//:protobuf.bzl", "py_proto_library") - -def tb_proto_library(name, srcs = [], visibility = []): - py_proto_library( - name = name + "_py", - srcs = srcs, - srcs_version = "PY2AND3", - deps = ["@protobuf//:protobuf_python"], - protoc = "@protobuf//:protoc", - visibility = visibility, - default_runtime = "@protobuf//:protobuf_python", - testonly = 0, - ) \ No newline at end of file diff --git a/tensorflow/tensorboard/defs/vulcanize.bzl b/tensorflow/tensorboard/defs/vulcanize.bzl deleted file mode 100644 index 6ff49a35ed73f0a8a5fb7ce5b3544e0807e1c0bc..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/vulcanize.bzl +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") -load("@io_bazel_rules_closure//closure/private:defs.bzl", "collect_js", "unfurl", "long_path") -load("//tensorflow/tensorboard/defs:web.bzl", "web_aspect") - -def _tensorboard_html_binary(ctx): - deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set(order="topological") - files = set() - webpaths = set() - for dep in deps: - manifests += dep.webfiles.manifests - webpaths += dep.webfiles.webpaths - files += dep.data_runfiles.files - webpaths += [ctx.attr.output_path] - closure_js_library=collect_js( - ctx, unfurl(ctx.attr.deps, provider="closure_js_library")) - - # vulcanize - jslibs = depset(ctx.files._jslibs) + closure_js_library.srcs - ctx.action( - inputs=list(manifests | files | jslibs), - outputs=[ctx.outputs.html], - executable=ctx.executable._Vulcanize, - arguments=([ctx.attr.compilation_level, - "true" if ctx.attr.testonly else "false", - ctx.attr.input_path, - ctx.attr.output_path, - ctx.outputs.html.path] + - [f.path for f in jslibs] + - [f.path for f in manifests]), - progress_message="Vulcanizing %s" % ctx.attr.input_path) - - # webfiles manifest - manifest_srcs = [struct(path=ctx.outputs.html.path, - longpath=long_path(ctx, ctx.outputs.html), - webpath=ctx.attr.output_path)] - manifest = ctx.new_file(ctx.configuration.bin_dir, - "%s.pbtxt" % ctx.label.name) - ctx.file_action( - output=manifest, - content=struct( - label=str(ctx.label), - src=manifest_srcs).to_proto()) - manifests += [manifest] - - # webfiles server - params = struct( - label=str(ctx.label), - bind="[::]:6006", - manifest=[long_path(ctx, man) for man in manifests], - external_asset=[struct(webpath=k, path=v) - for k, v in ctx.attr.external_assets.items()]) - params_file = ctx.new_file(ctx.configuration.bin_dir, - "%s_server_params.pbtxt" % ctx.label.name) - ctx.file_action(output=params_file, content=params.to_proto()) - ctx.file_action( - executable=True, - output=ctx.outputs.executable, - 
content="#!/bin/sh\nexec %s %s" % ( - ctx.executable._WebfilesServer.short_path, - long_path(ctx, params_file))) - - transitive_runfiles = depset() - transitive_runfiles += ctx.attr._WebfilesServer.data_runfiles.files - for dep in deps: - transitive_runfiles += dep.data_runfiles.files - return struct( - files=depset([ctx.outputs.html]), - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=ctx.outputs.html), - runfiles=ctx.runfiles( - files=ctx.files.data + [manifest, - params_file, - ctx.outputs.html, - ctx.outputs.executable], - transitive_files=transitive_runfiles)) - -tensorboard_html_binary = rule( - implementation=_tensorboard_html_binary, - executable=True, - attrs={ - "compilation_level": attr.string(default="ADVANCED"), - "input_path": attr.string(mandatory=True), - "output_path": attr.string(mandatory=True), - "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list( - aspects=[ - web_aspect, - legacy_js, - ], - mandatory=True), - "external_assets": attr.string_dict(default={"/_/runfiles": "."}), - "_jslibs": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:jslibs"), - allow_files=True), - "_Vulcanize": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Vulcanize"), - executable=True, - cfg="host"), - "_WebfilesServer": attr.label( - default=Label( - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), - executable=True, - cfg="host"), - }, - outputs={ - "html": "%{name}.html", - }) diff --git a/tensorflow/tensorboard/defs/web.bzl b/tensorflow/tensorboard/defs/web.bzl deleted file mode 100644 index 103942b0a25d2706b1af445383689dca02407d91..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/web.bzl +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Same as web_library but supports TypeScript.""" - -load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") - -load("//third_party:clutz.bzl", - "CLUTZ_ATTRIBUTES", - "CLUTZ_OUTPUTS", - "clutz_aspect", - "extract_dts_from_closure_libraries") - -load("@io_bazel_rules_closure//closure/private:defs.bzl", - "CLOSURE_LIBRARY_BASE_ATTR", - "CLOSURE_LIBRARY_DEPS_ATTR", - "collect_js", - "collect_runfiles", - "convert_path_to_es6_module_name", - "create_argfile", - "difference", - "long_path", - "unfurl") - -_ASPECT_SLURP_FILE_TYPE = FileType([ - ".html", ".js", ".css", ".gss", ".png", ".jpg", ".gif", ".ico", ".svg"]) - -_CLOSURE_WORKER = attr.label( - default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure:ClosureWorker"), - executable=True, - cfg="host") - -def _ts_web_library(ctx): - if not ctx.attr.srcs: - if ctx.attr.deps: - fail("deps can not be set when srcs is not") - if not ctx.attr.exports: - fail("exports must be set if srcs is not") - if ctx.attr.path: - if not ctx.attr.path.startswith("/"): - fail("webpath must start with /") - if ctx.attr.path != "/" and ctx.attr.path.endswith("/"): - fail("webpath must not end with / unless it is /") - if "//" in ctx.attr.path: - fail("webpath must not have //") - elif ctx.attr.srcs: - fail("path must be set when srcs is set") - if "*" in ctx.attr.suppress and len(ctx.attr.suppress) != 1: - fail("when \"*\" is suppressed no other items should be present") - - # 
process what came before - deps = unfurl(ctx.attr.deps, provider="webfiles") - webpaths = depset() - ts_typings = depset(ctx.files._default_typings) - ts_typings_paths = depset( - [long_path(ctx, f) for f in ctx.files._default_typings]) - ts_typings_execroots = depset() - aspect_runfiles = depset() - for dep in deps: - webpaths += dep.webfiles.webpaths - if hasattr(dep.webfiles, "ts_typings"): - ts_typings += dep.webfiles.ts_typings - if hasattr(dep.webfiles, "ts_typings_paths"): - ts_typings_paths += dep.webfiles.ts_typings_paths - if hasattr(dep.webfiles, "ts_typings_execroots"): - ts_typings_execroots += dep.webfiles.ts_typings_execroots - if hasattr(dep.webfiles, "aspect_runfiles"): - aspect_runfiles += dep.webfiles.aspect_runfiles - - # process what comes now - manifest_srcs = [] - new_webpaths = [] - ts_inputs = depset() - ts_outputs = [] - ts_files = list(ts_typings_paths) - new_typings = [] - new_typings_paths = [] - new_typings_execroot = struct(inputs=[]) - execroot = struct( - inputs=[(long_path(ctx, f), f.path) for f in ctx.files._default_typings], - outputs=[], - program=[ctx.executable._tsc.path, "-p"]) - web_srcs = [] - path = ctx.attr.path - strip = _get_strip(ctx) - for src in ctx.files.srcs: - suffix = _get_path_relative_to_package(src) - if strip: - if not suffix.startswith(strip): - fail("Relative src path not start with '%s': %s" % (strip, suffix)) - suffix = suffix[len(strip):] - webpath = "%s/%s" % ("" if path == "/" else path, suffix) - _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) - if suffix.endswith(".d.ts"): - web_srcs.append(src) - entry = (webpath[1:], src.path) - new_typings.append(src) - new_typings_paths.append(entry[0]) - new_typings_execroot.inputs.append(entry) - ts_inputs += [src] - ts_files.append(entry[0]) - execroot.inputs.append(entry) - elif suffix.endswith(".ts"): - noext = suffix[:-3] - js = ctx.new_file(ctx.bin_dir, "%s.js" % noext) - dts = ctx.new_file(ctx.bin_dir, "%s.d.ts" % noext) - 
webpath_js = webpath[:-3] + ".js" - webpath_dts = webpath[:-3] + ".d.ts" - _add_webpath(ctx, js, webpath_js, webpaths, new_webpaths, manifest_srcs) - _add_webpath(ctx, dts, webpath_dts, webpaths, new_webpaths, manifest_srcs) - ts_inputs += [src] - ts_outputs.append(js) - ts_outputs.append(dts) - web_srcs.append(dts) - web_srcs.append(js) - ts_files.append(webpath[1:]) - execroot.inputs.append((webpath[1:], src.path)) - execroot.outputs.append((webpath_js[1:], js.path)) - execroot.outputs.append((webpath_dts[1:], dts.path)) - new_typings.append(dts) - new_typings_paths.append(webpath_dts[1:]) - new_typings_execroot.inputs.append((webpath_dts[1:], dts.path)) - else: - web_srcs.append(src) - - # get typings for closure code - clutz_dts = extract_dts_from_closure_libraries(ctx) - if clutz_dts: - entry = (long_path(ctx, clutz_dts), clutz_dts.path) - ts_inputs += [clutz_dts] - ts_files.append(entry[0]) - execroot.inputs.append(entry) - - # compile typescript - workspace = "" - if ctx.label.workspace_root: - workspace = "/" + ctx.label.workspace_root - if execroot.outputs: - ts_config = _new_file(ctx, "-tsc.json") - execroot.inputs.append(("tsconfig.json", ts_config.path)) - ctx.file_action( - output=ts_config, - content=struct( - compilerOptions=struct( - baseUrl=".", - declaration=True, - inlineSourceMap=True, - inlineSources=True, - module="es6", - moduleResolution="node", - noResolve=True, - target="es5", - ), - files=ts_files, - ).to_json()) - er_config = _new_file(ctx, "-tsc-execroot.json") - ctx.file_action(output=er_config, content=execroot.to_json()) - ts_inputs += collect_runfiles([ctx.attr._tsc]) - ts_inputs += ctx.files._tsc - ts_inputs += ts_typings - ts_inputs += ts_typings_execroots - ts_inputs += [ts_config, er_config] - ctx.action( - inputs=list(ts_inputs), - outputs=ts_outputs, - executable=ctx.executable._execrooter, - arguments=[er_config.path] + [f.path for f in ts_typings_execroots], - progress_message="Compiling %d TypeScript files %s" % ( - 
len(ts_files), ctx.label)) - - # perform strict dependency checking - manifest = _make_manifest(ctx, manifest_srcs) - webpaths += new_webpaths - dummy, manifests = _run_webfiles_validator(ctx, web_srcs, deps, manifest) - web_srcs.append(dummy) - - # define development web server that only applies to this transitive closure - params = struct( - label=str(ctx.label), - bind="[::]:6006", - manifest=[long_path(ctx, man) for man in manifests], - external_asset=[struct(webpath=k, path=v) - for k, v in ctx.attr.external_assets.items()]) - params_file = _new_file(ctx, "-params.pbtxt") - ctx.file_action(output=params_file, content=params.to_proto()) - ctx.file_action( - executable=True, - output=ctx.outputs.executable, - content="#!/bin/sh\nexec %s %s" % ( - ctx.executable._WebfilesServer.short_path, - long_path(ctx, params_file))) - - if new_typings: - er_config = _new_file(ctx, "-typings-execroot.json") - ctx.file_action(output=er_config, content=new_typings_execroot.to_json()) - ts_typings += new_typings - ts_typings_paths += new_typings_paths - ts_typings_execroots += [er_config] - else: - ts_typings = depset() - ts_typings_paths = depset() - ts_typings_execroots = depset() - - # export data to parent rules - return struct( - files=depset(web_srcs + [dummy]), - exports=unfurl(ctx.attr.exports), - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=dummy, - ts_typings=ts_typings, - ts_typings_paths=ts_typings_paths, - ts_typings_execroots=ts_typings_execroots), - closure_js_library=collect_js( - ctx, unfurl(ctx.attr.deps, provider="closure_js_library")), - runfiles=ctx.runfiles( - files=ctx.files.srcs + ctx.files.data + ts_outputs + [ - manifest, - params_file, - ctx.outputs.executable, - dummy], - transitive_files=(collect_runfiles([ctx.attr._WebfilesServer]) | - collect_runfiles(deps) | - collect_runfiles(ctx.attr.data) | - aspect_runfiles))) - -def _web_aspect_impl(target, ctx): - if hasattr(target, "webfiles"): - return struct() 
- srcs = [] - deps = [] - if hasattr(ctx.rule.files, "srcs"): - srcs.extend(_ASPECT_SLURP_FILE_TYPE.filter(ctx.rule.files.srcs)) - for attr in ("deps", "sticky_deps", "module_deps"): - value = getattr(ctx.rule.attr, attr, None) - if value: - deps.extend(value) - deps = unfurl(deps, provider="webfiles") - webpaths = depset() - aspect_runfiles = depset(srcs) - for dep in deps: - webpaths += dep.webfiles.webpaths - if hasattr(dep.webfiles, "aspect_runfiles"): - aspect_runfiles += dep.webfiles.aspect_runfiles - manifest_srcs = [] - new_webpaths = [] - for src in srcs: - webpath = "/" + long_path(ctx, src) - _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) - webpaths += new_webpaths - manifest = _make_manifest(ctx, manifest_srcs) - dummy, manifests = _run_webfiles_validator(ctx, srcs, deps, manifest) - aspect_runfiles += [dummy, manifest] - return struct( - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=dummy, - aspect_runfiles=aspect_runfiles)) - -def _make_manifest(ctx, src_list): - manifest = _new_file(ctx, "-webfiles.pbtxt") - ctx.file_action( - output=manifest, - content=struct( - label=str(ctx.label), - src=src_list).to_proto()) - return manifest - -def _run_webfiles_validator(ctx, srcs, deps, manifest): - dummy = _new_file(ctx, "-webfiles.ignoreme") - manifests = depset(order="topological") - for dep in deps: - manifests += dep.webfiles.manifests - if srcs: - args = ["WebfilesValidator", - "--dummy", dummy.path, - "--target", manifest.path] - if hasattr(ctx, "attr") and hasattr(ctx.attr, "suppress"): - for category in ctx.attr.suppress: - args.append("--suppress") - args.append(category) - inputs = [manifest] - inputs.extend(srcs) - direct_manifests = depset() - for dep in deps: - inputs.append(dep.webfiles.dummy) - for f in dep.files: - inputs.append(f) - direct_manifests += [dep.webfiles.manifest] - inputs.append(dep.webfiles.manifest) - args.append("--direct_dep") - 
args.append(dep.webfiles.manifest.path) - for man in difference(manifests, direct_manifests): - inputs.append(man) - args.append("--transitive_dep") - args.append(man.path) - argfile = _new_file(ctx, "-webfiles-checker-args.txt") - ctx.file_action(output=argfile, content="\n".join(args)) - inputs.append(argfile) - ctx.action( - inputs=inputs, - outputs=[dummy], - executable=(getattr(ctx.executable, "_ClosureWorker", None) or - getattr(ctx.executable, "_ClosureWorkerAspect", None)), - arguments=["@@" + argfile.path], - mnemonic="Closure", - execution_requirements={"supports-workers": "1"}, - progress_message="Checking webfiles %s" % ctx.label) - else: - ctx.file_action(output=dummy, content="BOO!") - manifests += [manifest] - return dummy, manifests - -def _new_file(ctx, suffix): - return ctx.new_file(ctx.bin_dir, "%s%s" % (ctx.label.name, suffix)) - -def _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs): - if webpath in new_webpaths: - _fail(ctx, "multiple srcs within %s define the webpath %s " % ( - ctx.label, webpath)) - if webpath in webpaths: - _fail(ctx, "webpath %s was defined by %s when already defined by deps" % ( - webpath, ctx.label)) - new_webpaths.append(webpath) - manifest_srcs.append(struct( - path=src.path, - longpath=long_path(ctx, src), - webpath=webpath)) - -def _fail(ctx, message): - if ctx.attr.suppress == ["*"]: - print(message) - else: - fail(message) - -def _get_path_relative_to_package(artifact): - """Returns file path relative to the package that declared it.""" - path = artifact.path - for prefix in (artifact.root.path, - artifact.owner.workspace_root if artifact.owner else '', - artifact.owner.package if artifact.owner else ''): - if prefix: - prefix = prefix + "/" - if not path.startswith(prefix): - fail("Path %s doesn't start with %s" % (path, prefix)) - path = path[len(prefix):] - return path - -def _get_strip(ctx): - strip = ctx.attr.strip_prefix - if strip: - if strip.startswith("/"): - _fail(ctx, "strip_prefix 
should not end with /") - strip = strip[1:] - if strip.endswith("/"): - _fail(ctx, "strip_prefix should not end with /") - else: - strip += "/" - return strip - -web_aspect = aspect( - implementation=_web_aspect_impl, - attr_aspects=["deps", "sticky_deps", "module_deps"], - attrs={"_ClosureWorkerAspect": _CLOSURE_WORKER}) - -ts_web_library = rule( - implementation=_ts_web_library, - executable=True, - attrs=CLUTZ_ATTRIBUTES + { - "path": attr.string(), - "srcs": attr.label_list(allow_files=True), - "deps": attr.label_list( - aspects=[ - web_aspect, - clutz_aspect, - legacy_js, - ]), - "exports": attr.label_list(), - "data": attr.label_list(cfg="data", allow_files=True), - "suppress": attr.string_list(), - "strip_prefix": attr.string(), - "external_assets": attr.string_dict(default={"/_/runfiles": "."}), - "clutz_entry_points": attr.string_list(), - "_execrooter": attr.label( - default=Label("//tensorflow/tensorboard/scripts:execrooter"), - executable=True, - cfg="host"), - "_tsc": attr.label( - default=Label("@com_microsoft_typescript//:tsc"), - allow_files=True, - executable=True, - cfg="host"), - "_default_typings": attr.label( - default=Label("//tensorflow/tensorboard:ts_web_library_default_typings"), - allow_files=True), - "_WebfilesServer": attr.label( - default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), - executable=True, - cfg="host"), - "_ClosureWorker": _CLOSURE_WORKER, - "_closure_library_base": CLOSURE_LIBRARY_BASE_ATTR, - "_closure_library_deps": CLOSURE_LIBRARY_DEPS_ATTR, - }, - outputs=CLUTZ_OUTPUTS) diff --git a/tensorflow/tensorboard/defs/zipper.bzl b/tensorflow/tensorboard/defs/zipper.bzl deleted file mode 100644 index e98309ec9a5d5185ac48e235ceb10d0d3f0e153d..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/zipper.bzl +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@io_bazel_rules_closure//closure/private:defs.bzl", "unfurl", "long_path") - -def _tensorboard_zip_file(ctx): - deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set(order="link") - files = set() - webpaths = set() - for dep in deps: - manifests += dep.webfiles.manifests - webpaths += dep.webfiles.webpaths - files += dep.data_runfiles.files - ctx.action( - inputs=list(manifests + files), - outputs=[ctx.outputs.zip], - executable=ctx.executable._Zipper, - arguments=([ctx.outputs.zip.path] + - [m.path for m in manifests]), - progress_message="Zipping %d files" % len(webpaths)) - transitive_runfiles = set() - for dep in deps: - transitive_runfiles += dep.data_runfiles.files - return struct( - files=set([ctx.outputs.zip]), - runfiles=ctx.runfiles( - files=ctx.files.data + [ctx.outputs.zip], - transitive_files=transitive_runfiles)) - -tensorboard_zip_file = rule( - implementation=_tensorboard_zip_file, - attrs={ - "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list(providers=["webfiles"], mandatory=True), - "_Zipper": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Zipper"), - executable=True, - cfg="host"), - }, - outputs={ - "zip": "%{name}.zip", - }) diff --git a/tensorflow/tensorboard/demo/BUILD b/tensorflow/tensorboard/demo/BUILD deleted file mode 100644 index 
b253572ec556314356dee4911eeb755e6da18950..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") - -licenses(["notice"]) # Apache 2.0 - -# THIS PACKAGE HAS MOVED -# See tensorflow/tensorboard/components/tf_tensorboard:demo - -web_library( - name = "demo_data", - srcs = glob(["data/**"]), - path = "/", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json b/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json deleted file mode 100644 index 7dfe32c7112c61bcacf896de2d906bc06a9c952f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"query": "index=0&tag=au1%2Faudio%2F0&run=run1", "step": 0, "wall_time": 1461795049.203407, "content_type": "audio/wav"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json b/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json deleted file mode 100644 index 13f9c2de4265d08a3b3635360d380c018f7aed7b..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"query": "index=0&tag=au2%2Faudio%2F0&run=run2", "step": 0, "wall_time": 1461795049.212815, "content_type": "audio/wav"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json deleted file mode 100644 index 6ae6fbf880e61bb8f7dfe3ed0a32dcba3e5d40cd..0000000000000000000000000000000000000000 --- 
a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -2.3150592308536755], [668, -2.0967547155036605], [1587, -1.4326244423655616], [3085, -0.8871306575801902], [5000, -0.09312398815580714], [6915, 0.2584093405812282], [8413, 0.8895470642005087], [9332, 1.3198979614453679], [10000, 1.6793308878855118]]], [100.0, 10, [[0, -1.3417572789138936], [668, -1.183563374619141], [1587, -0.48920418783271574], [3085, 0.29326906896076954], [5000, 0.56953784145381], [6915, 0.8684655583499333], [8413, 1.4133127368907181], [9332, 1.906140650457873], [10000, 2.135771998171255]]], [200.0, 20, [[0, -1.5066917525035333], [668, -1.3910909571770793], [1587, -0.902737218885874], [3085, -0.3807791904765027], [5000, 0.38900200905253046], [6915, 0.8209734209339482], [8413, 1.302385856695965], [9332, 1.9324626053521639], [10000, 2.957505317875451]]], [300.0, 30, [[0, -0.5430457051469562], [668, -0.4626161834245273], [1587, 0.21573949543027715], [3085, 0.37353741100174215], [5000, 0.6891407881591103], [6915, 1.0927156232630852], [8413, 1.2745337159550916], [9332, 1.4321116832891605], [10000, 2.1913774993059034]]], [400.0, 40, [[0, -0.3584790755077172], [668, -0.33301611509753215], [1587, -0.1089466072951948], [3085, 0.5792199847585249], [5000, 1.220854943811942], [6915, 1.759829438421432], [8413, 2.3072559906741614], [9332, 2.753036118353921], [10000, 3.0267252195784047]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json deleted file mode 100644 index 3ad520c5687cdec798b401d3740814de75d39bc8..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -3.6801669545044846], [668, -3.192188140974744], [1587, -2.3414678549368806], [3085, -0.9632173471995873], [5000, 
-0.3214892636797772], [6915, 0.11870794142185205], [8413, 0.8895470642005087], [9332, 1.183563374619141], [10000, 2.665663810418372]]], [100.0, 10, [[0, -3.564793583751807], [668, -3.376844436865802], [1587, -1.0366615731293798], [3085, -0.27318696312672563], [5000, 0.9718642422053263], [6915, 2.5765662807928194], [8413, 3.1415385101545126], [9332, 4.085981768607621], [10000, 4.623079406808927]]], [200.0, 20, [[0, -2.235172510433281], [668, -2.004569042815611], [1587, -1.2015432383370985], [3085, 0.11835464933202625], [5000, 0.56953784145381], [6915, 1.202844810963146], [8413, 2.689066032283515], [9332, 2.8494015726499944], [10000, 3.481377676013788]]], [300.0, 30, [[0, -3.360113978269659], [668, -2.8293185004961043], [1587, -1.5992540502266783], [3085, 0.14393860259807117], [5000, 1.47723448201245], [6915, 1.9510057389110733], [8413, 2.833176104473626], [9332, 4.142405216576347], [10000, 4.706937777668589]]], [400.0, 40, [[0, -2.599286228987632], [668, -2.240365897443259], [1587, -1.5992540502266783], [3085, -0.9101893288861387], [5000, 0.7580548669750213], [6915, 1.6009864433919474], [8413, 2.3504002974280036], [9332, 2.7907805263353733], [10000, 3.5098048900144323]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json deleted file mode 100644 index a3802ba2365adadb2453809fdf77d07ee5ef9b1f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -1.9291158122759586], [668, -1.5970765333488954], [1587, -1.0923120348519078], [3085, -0.6688082872192093], [5000, 0.09312398815580714], [6915, 0.44532789251701854], [8413, 0.8238009655877649], [9332, 1.0357232383581656], [10000, 1.2741043689144438]]], [100.0, 10, [[0, -0.7780725642449806], [668, -0.7138496178727424], [1587, -0.5448932415735014], [3085, 
-0.24370397454796228], [5000, 0.42790220995778355], [6915, 0.6191730643365096], [8413, 0.752059342118037], [9332, 1.0451472255274825], [10000, 2.5559479569222825]]], [200.0, 20, [[0, -1.3876904425996377], [668, -1.1464188862638496], [1587, -0.4049955219067526], [3085, 0.04721394862139682], [5000, 0.56953784145381], [6915, 1.3221859041483333], [8413, 1.6188495656305735], [9332, 1.7613953069723651], [10000, 2.3257482385477384]]], [300.0, 30, [[0, -1.600772629982185], [668, -1.1548516185367033], [1587, -0.260387173785447], [3085, 0.17416570914366614], [5000, 0.47069243095356195], [6915, 1.1559276581637614], [8413, 2.0474031182051404], [9332, 2.18821711651116], [10000, 2.2393193406467518]]], [400.0, 40, [[0, -0.8286852465281818], [668, -0.7815041529866706], [1587, -0.3334896444053469], [3085, 0.21085213041026643], [5000, 0.5177616740489182], [6915, 1.077122434649409], [8413, 1.5898009703967424], [9332, 1.8859097291499742], [10000, 2.0954239138728523]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt b/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt deleted file mode 100644 index 2a6af3284086b4d797ebf3598bffe286d74baddf..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt +++ /dev/null @@ -1,9 +0,0 @@ -node { - name: "a" - op: "matmul" -} -node { - name: "b" - op: "matmul" - input: "a:0" -} diff --git a/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt b/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt deleted file mode 100644 index a5a4d65d5c61a7cf1c208b48f841a38a03847d60..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt +++ /dev/null @@ -1,15 +0,0 @@ -node { - name: "a" - op: "matmul" -} -node { - name: "b" - op: "matmul" - input: "a:0" -} -node { - name: "c" - op: "matmul" - input: "a:0" - input: "b:0" -} diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json 
b/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json deleted file mode 100644 index a5600a356e8277e58be3b2891c3e328d058b5d08..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.3584790755077172, 3.0267252195784047, 20.0, 24.012225532303315, 48.29045006426564, [-0.35363819004775493, -0.29226296698161564, -0.19961953895336082, 0.3214892636797772, 0.5177616740489182, 0.56953784145381, 0.6264916255991911, 0.7580548669750213, 0.8338603536725235, 1.220854943811942, 1.3429404381931362, 1.47723448201245, 1.624957930213695, 1.7874537232350647, 1.9661990955585713, 2.379100905625872, 2.6170109961884593, 3.1665833053880363], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json deleted file mode 100644 index 407c375d2fc710e70408a3238df3a6165e964e84..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-2.599286228987632, 3.5098048900144323, 20.0, 10.792285491200078, 66.66796979177158, [-2.379100905625872, -1.9661990955585713, -1.624957930213695, -1.47723448201245, -1.109868130738129, -1.0089710279437536, -0.42790220995778355, -0.2195814928486969, 0.47069243095356195, 0.7580548669750213, 0.917246389039776, 1.3429404381931362, 1.624957930213695, 1.7874537232350647, 2.1628190051144287, 2.6170109961884593, 2.8787120958073054, 3.8315657995195243], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json deleted file mode 
100644 index 752b621ab032f24805574708e1659c7139a701a8..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.8286852465281818, 2.0954239138728523, 20.0, 13.546880465642861, 24.14836803774091, [-0.7580548669750213, -0.38900200905253046, -0.06996543062044111, 0.07696197368248522, 0.19961953895336082, 0.2656936063469233, 0.29226296698161564, 0.5177616740489182, 0.7580548669750213, 0.917246389039776, 1.109868130738129, 1.220854943811942, 1.624957930213695, 2.1628190051144287], [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 3.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 814b4193c638749620e86ac21b86c48747f18f4c..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.088045, "width": 4, "height": 4, "step": 0, "query": "tag=im1%2Fimage%2F0&index=0&run=run1"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json deleted file mode 100644 index 0c2bdcfc79cb32433ac987752851ef6dd351b058..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.093653, "width": 4, "height": 4, "step": 0, "query": "tag=im2%2Fimage%2F0&index=0&run=run1"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 
3160aae366d904d5be5be22d60ca1b345a9d5172..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.117463, "width": 4, "height": 4, "step": 0, "query": "tag=im1%2Fimage%2F0&index=0&run=run2"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav b/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav deleted file mode 100644 index f1d24adc0cef5a734e07e8899b9abf8ae26fa228..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav b/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav deleted file mode 100644 index 006c84338f7313a225830f121bcd95f457de1708..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 346fd0076be28b9338152c4d49a32fc5ed685e44..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png deleted file mode 100644 index 26d2d10acaf8511efeb03169853092d09252215b..0000000000000000000000000000000000000000 Binary files 
a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 6c4190629429e0929962c4f20bd1a1602620e4bd..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/logdir b/tensorflow/tensorboard/demo/data/logdir deleted file mode 100644 index b6362b45d777266d6204b23884222a080f789f71..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/runs.json b/tensorflow/tensorboard/demo/data/runs.json deleted file mode 100644 index e09039054299cdc3e3453c620761e1ed6e0c0169..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/runs.json +++ /dev/null @@ -1 +0,0 @@ -{"run1": {"scalars": ["foo/sin", "foo/cos", "foo/square", "bar/square"], "run_metadata": [], "compressedHistograms": ["histo1"], "images": ["im1/image/0", "im2/image/0"], "histograms": ["histo1"], "graph": true, "audio": ["au1/audio/0"]}, "run2": {"scalars": ["foo/cos", "foo/square", "bar/square"], "run_metadata": [], "compressedHistograms": ["histo2", "histo1"], "images": ["im1/image/0"], "histograms": ["histo2", "histo1"], "graph": true, "audio": ["au2/audio/0"]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars.json b/tensorflow/tensorboard/demo/data/scalars.json deleted file mode 100644 index bc269395b68a35f7d4481fca05063e46c79c2859..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars.json +++ /dev/null @@ -1 +0,0 @@ -{"run2": 
{"foo/cos": [[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]]}, "run1": {"foo/sin": [[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]], "foo/cos": [[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e1cd0a6a56d3d87b7183f55ac52ba6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json deleted file mode 100644 index 025eaa16e93110da0c50ad03486786ee6e521700..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json 
b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json deleted file mode 100644 index eae69dd78f3b5aa75acec6b5daa08720fad9adba..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e1cd0a6a56d3d87b7183f55ac52ba6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e1cd0a6a56d3d87b7183f55ac52ba6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json deleted file mode 100644 index dd3593f9d109e81bef5a10c732a9e08e60b3ef4f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]] \ No newline at end of file diff --git 
a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json deleted file mode 100644 index 0ff9ef0551d0a3053ba16b502d0d6148057df660..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md deleted file mode 100644 index c2885daf93c29b5c39b68619d26623c666e28627..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/http_api.md +++ /dev/null @@ -1,402 +0,0 @@ -# Tensorboard client-server HTTP API - -## Runs, Tags, and Tag Types - -TensorBoard data is organized around the concept of a `run`, which represents -all the related data thrown off by a single execution of TensorFlow, a `tag`, -which groups values of data that come from the same source within a TensorFlow -run, and `tag types`, which are our way of distinguishing different types of -data that have fundamentally different representations and should be processed -on different code paths. For example, a "train" run may have a `scalars` -tag that represents the learning rate, another `scalars` tag that -represents the value of the objective function, a `histograms` tag that reveals -information on weights in a particular layer over time, and an `images` tag that -shows input images flowing into the system. The "eval" run might have an -entirely different set of tag names, or some duplicated tag names. - -The currently supported tag types are `scalars`, `images`, `audio`, -`histograms`, `graph` and `run_metadata`. Each tag type corresponds to a route -(documented below) for retrieving tag data of that type. 
- -All of the data provided comes from TensorFlow events files ('\*.tfevents\*'), -which are written using the SummaryWriter class -(tensorflow/python/training/summary_writer.py), and the data is generated by -summary ops (tensorflow/python/ops/summary_ops.py). The `scalars` come from the -`ScalarSummary` op, the `histograms` from the `HistogramSummary`, the `audio` -from the `AudioSummary`, and the `images` from `ImageSummary`. The tag type -`graph` is special in that it is not a collection of tags of that type, but a -boolean denoting if there is a graph definition associated with the run. The tag -is provided to the summary op (usually as a constant). - -## `data/logdir` - -Returns a JSON object with a key "logdir" that maps to the `logdir` argument -(string) with which Tensorboard started up. Example: -`{logdir: '/foo/logdir/argument'}` - -The `logdir` argument is the path of the directory that contains events files. - -## `data/plugins_listing` - -Returns a dict mapping from plugin name to a boolean indicating whether the -plugin is active. A plugin might be inactive, for instance, if it lacks relevant -data. Every plugin has a key. This route helps the frontend avoid issuing -requests to an inactive plugin - the routes of an inactive plugin do not work. - -## `data/runs` - -Returns an array containing the names of all the runs known to the -TensorBoard backend at this time. Each entry is a string corresponding -to a single run. - -We guarantee that as new runs are created in the log directory, they -will always appear at the end of the list returned by this route. That -is, the order of runs is persistent, and the result of this route is an -“append-only” list. - -Example response: - - ["train_run", "eval"] - -## `/data/plugin/scalars/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -scalar tags present in the corresponding run. 
Here is an example: - - { - "train_run": ["xent", "loss", "learning_rate"], - "eval": ["precision", "recall"] - } - -Note that runs without any scalar tags are included as keys with value the -empty array. - -## `/data/plugin/scalars/scalars?run=foo&tag=bar` - -Returns an array of event_accumulator.SimpleValueEvents ([wall_time, step, -value]) for the given run and tag. wall_time is seconds since epoch. - -Example: - - [ - [1443856985.705543, 1448, 0.7461960315704346], # wall_time, step, value - [1443857105.704628, 3438, 0.5427092909812927], - [1443857225.705133, 5417, 0.5457325577735901], - ... - ] - -If the format parameter is set to 'csv', the response will instead be in CSV -format: - - Wall time,step,value - 1443856985.705543,1448,0.7461960315704346 - 1443857105.704628,3438,0.5427092909812927 - 1443857225.705133,5417,0.5457325577735901 - -## `/data/plugin/histograms/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -histogram tags present in the corresponding run. Here is an example: - - { - "train_run": ["foo_histogram", "bar_histogram"], - "eval": ["foo_histogram", "bar_histogram"] - } - -Note that runs without any histogram tags are included as keys with -value the empty array. - -## `/data/plugin/histograms/histograms?run=foo&tag=bar` - -Returns an array of event_accumulator.HistogramEvents ([wall_time, step, -HistogramValue]) for the given run and tag. A HistogramValue is [min, max, num, -sum, sum_squares, bucket_limit, bucket]. wall_time is seconds since epoch. 
- -Annotated Example: (note - real data is higher precision) - - [ - [ - 1443871386.185149, # wall_time - 235166, # step - [ - -0.66, # minimum value - 0.44, # maximum value - 8.0, # number of items in the histogram - -0.80, # sum of items in the histogram - 0.73, # sum of squares of items in the histogram - [-0.68, -0.62, -0.292, -0.26, -0.11, -0.10, -0.08, -0.07, -0.05, - -0.0525, -0.0434, -0.039, -0.029, -0.026, 0.42, 0.47, 1.8e+308], - # the right edge of each bucket - [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, - 1.0, 0.0] # the number of elements within each bucket - ] - ] - ] - -## `/data/plugin/distributions/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -distribution tags present in the corresponding run. Here is an example: - - { - "train_run": ["foo_histogram", "bar_histogram"], - "eval": ["foo_histogram", "bar_histogram"] - } - -Note that runs without any distribution tags are included as keys with -value the empty array. - -## `/data/plugin/distributions/distributions?run=foo&tag=bar` - -Returns an array of event_accumulator.CompressedHistogramEvents ([wall_time, -step, CompressedHistogramValues]) for the given run and tag. - -CompressedHistogramValues is a list of namedtuples with each tuple specifying -a basis point (bps) as well as an interpolated value of the histogram value -at that basis point. A basis point is 1/100 of a percent. - -The current compression strategy is to choose basis points that correspond to -the median and bands of 1SD, 2SD, and 3SDs around the median. Note that the -current compression strategy does not work well for representing multimodal -data -- this is something that will be improved in a later iteration. 
- -Annotated Example: (note - real data is higher precision) - - [ - [ - 1441154832.580509, # wall_time - 5, # step - [ [0, -3.67], # CompressedHistogramValue for 0th percentile - [2500, -4.19], # CompressedHistogramValue for 25th percentile - [5000, 6.29], - [7500, 1.64], - [10000, 3.67] - ] - ], - ... - ] - -## `/data/plugin/images/images?run=foo&tag=bar` - -Gets a sample of ImageMetadatas for the given run and tag. - -Returns an array of objects containing information about available images, -crucially including the query parameter that may be used to retrieve that image. -(See /data/plugin/images/individualImage for details.) - -For example: - - { - "width": 28, # width in pixels - "height": 28, # height in pixels - "wall_time": 1440210599.246, # time in seconds since epoch - "step": 63702821, # number of steps that have passed - "query": "index=0&tagname=input%2Fimage%2F2&run=train" - # param for /individualImage - } - -## `/data/plugin/images/individualImage?{{query}}` - -Retrieves an individual image. The image query should not be generated by the -frontend, but instead acquired from calling the /images route (the image -metadata objects contain the query to use). The response is the image itself -with mime-type 'image/png'. - -Note that the query is not guaranteed to always refer to the same image even -within a single run, as images may be removed from the sampling reservoir and -replaced with other images. (See Notes for details on the reservoir sampling.) - -An example call to this route would look like this: -/data/plugin/images/individualImage?index=0&tagname=input%2Fimage%2F2&run=train - -## `/data/plugin/images/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all image -tags present in the corresponding run. 
Here is an example: - - { - "train": ["foo_image", "bar_image"], - "eval": ["foo_image", "bar_image"] - } - -Note that runs without any image tags are included as keys with value the empty -array. - -## `/data/plugin/audio/audio?run=foo&tag=bar` - -Gets a sample of AudioMetadatas for the given run and tag. - -Returns an array of objects containing information about available audio, -crucially including the query parameter that may be used to retrieve that audio. -(See /data/plugin/audio/individualAudio for details.) - -For example: - - { - "wall_time": 1440210599.246, # time in seconds since epoch - "step": 63702821, # number of steps that have passed - "content_type": "audio/wav" # the MIME-type of the audio - "query": "index=0&tagname=input%2Faudio%2F2&run=train" - # param for /individualAudio - } - -## `/data/plugin/audio/individualAudio?{{query}}` - -Retrieves an individual audio clip. The audio query should not be generated by -the frontend, but instead acquired from calling the /audio route (the audio -metadata objects contain the query to use). The response is the audio itself -with an appropriate Content-Type header set. - -Note that the query is not guaranteed to always refer to the same clip even -within a single run, as audio may be removed from the sampling reservoir and -replaced with other clips. (See Notes for details on the reservoir sampling.) - -An example call to this route would look like this: -/individualAudio?index=0&tagname=input%2Faudio%2F2&run=train - -## `/data/plugin/audio/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all audio -tags present in the corresponding run. Here is an example: - - { - "train": ["foo_audio", "bar_audio"], - "eval": ["foo_audio", "bar_audio"], - } - -Note that runs without any audio tags are included as keys with value the empty -array. 
- -## `/data/plugin/graphs/runs` - -Returns a list of runs that have associated graphs. - -For example: - - ["train"] - -## `/data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` - -Returns the graph definition for the given run in pbtxt format. The -graph is composed of a list of nodes, where each node is a specific -TensorFlow operation which takes as inputs other nodes (operations). - -The query parameters `limit_attr_size` and `large_attrs_key` are optional. - -`limit_attr_size` specifies the maximum allowed size in bytes, before the -attribute is considered large and filtered out of the graph. If specified, -it must be an int and > 0. If not specified, no filtering is applied. - -`large_attrs_key` is the attribute key that will be used for storing -attributes that are too large. The value of this key (list of strings) -should be used by the client in order to determine which attributes -have been filtered. Must be specified if `limit_attr_size` is specified. - -For the query - - /data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large, - -here is an example pbtxt response of a graph with 3 nodes, where the second -node had two large attributes "a" and "b" that were filtered out (size > 1024): - - node { - op: "Input" - name: "A" - } - node { - op: "Input" - name: "B" - attr { - key: "small_attr" - value: { - s: "some string" - } - } - attr { - key: "_too_large" - value { - list { - s: "a" - s: "b" - } - } - } - } - node { - op: "MatMul" - name: "C" - input: "A" - input: "B" - } - -Prior to filtering, the original node "B" had the following content: - - node { - op: "Input" - name: "B" - attr { - key: "small_attr" - value: { - s: "some string" - } - } - attr { - key: "a" - value { Very large object... } - } - attr { - key: "b" - value { Very large object... 
} - } - } - -## `/data/run_metadata?run=foo&tag=bar` - -Given a run and tag, returns the metadata of a particular -`session.run()` as a gzipped, pbtxt serialized [`RunMetadata`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto) -proto. For example: - - step_stats { - dev_stats { - device: "/job:localhost/replica:0/task:0/cpu:0" - node_stats { - node_name: "_SOURCE" - all_start_micros: 1458337695775395 - op_start_rel_micros: 11 - op_end_rel_micros: 12 - all_end_rel_micros: 38 - memory { - allocator_name: "cpu" - } - timeline_label: "_SOURCE = NoOp()" - scheduled_micros: 1458337695775363 - } - } - } - -## Notes - -All returned values, histograms, audio, and images are returned in the order -they were written by TensorFlow (which should correspond to increasing -`wall_time` order, but may not necessarily correspond to increasing step count -if the process had to restart from a previous checkpoint). - -The returned values may be downsampled using reservoir sampling, which is -configurable by the TensorBoard server. When downsampling occurs, the server -guarantees that different tags will all sample at the same sequence of indices, -so that if there are two tags `A` and `B` which are related so that `A[i] ~ -B[i]` for all `i`, then `D(A)[i] ~ D(B)[i]` for all `i`, where `D` represents -the downsampling operation. - -The reservoir sampling puts an upper bound on the number of items that will be -returned for a given run-tag combination, and guarantees that all items are -equally likely to be in the final sample (ie it is a uniform distribution over -the values), with the proviso that the most recent individual item is always -included in the sample. - -The reservoir sizes are configurable on a per-tag type basis. 
diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD deleted file mode 100644 index f1f7746ff846e549f3473412470bbff3970a7741..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -java_binary( - name = "Vulcanize", - srcs = ["Vulcanize.java"], - jvm_flags = [ - "-Xss20m", # JSCompiler needs big stacks for recursive parsing - "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive - "-Djava.util.logging.SimpleFormatter.format='%1$$tY-%1$$tm-%1$$td %1$$tH:%1$$tM:%1$$tS.%1$$tL %4$$-6s %5$$s%6$$s%n'", # Less log spam - ], - visibility = ["//visibility:public"], - deps = [ - "@com_google_guava", - "@com_google_protobuf_java", - "@io_bazel_rules_closure//closure/compiler", - "@io_bazel_rules_closure//java/io/bazel/rules/closure:webpath", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", - "@io_bazel_rules_closure//java/org/jsoup/nodes", - "@org_jsoup", - ], -) - -java_binary( - name = "Zipper", - srcs = ["Zipper.java"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_guava", - "@com_google_protobuf_java", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", - ], -) - -# These JS files are always taken into consideration by the Closure Compiler -# when vulcanizing, per vulcanize.bzl. 
-filegroup( - name = "jslibs", - srcs = [ - # Ordering probably matters - "@com_google_javascript_closure_compiler_externs", - "@com_google_javascript_closure_compiler_externs_polymer", - "externs.js", - "@com_google_javascript_closure_library//:closure/goog/base.js", - "@com_google_javascript_closure_library//:closure/goog/deps.js", - ], - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java deleted file mode 100644 index 533907dd64dd84107d46dd7411235c4ff8aaa755..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java +++ /dev/null @@ -1,546 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package org.tensorflow.tensorboard.vulcanize; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Verify.verify; -import static com.google.common.base.Verify.verifyNotNull; -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.common.base.CharMatcher; -import com.google.common.base.Joiner; -import com.google.common.base.Optional; -import com.google.common.base.Splitter; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; -import com.google.common.collect.Multimap; -import com.google.javascript.jscomp.CheckLevel; -import com.google.javascript.jscomp.CompilationLevel; -import com.google.javascript.jscomp.Compiler; -import com.google.javascript.jscomp.CompilerOptions; -import com.google.javascript.jscomp.DiagnosticGroup; -import com.google.javascript.jscomp.DiagnosticGroups; -import com.google.javascript.jscomp.DiagnosticType; -import com.google.javascript.jscomp.JSError; -import com.google.javascript.jscomp.ModuleIdentifier; -import com.google.javascript.jscomp.PropertyRenamingPolicy; -import com.google.javascript.jscomp.Result; -import com.google.javascript.jscomp.SourceFile; -import com.google.javascript.jscomp.WarningsGuard; -import com.google.protobuf.TextFormat; -import io.bazel.rules.closure.Webpath; -import io.bazel.rules.closure.webfiles.BuildInfo.Webfiles; -import io.bazel.rules.closure.webfiles.BuildInfo.WebfilesSource; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Deque; -import java.util.HashMap; -import 
java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import org.jsoup.Jsoup; -import org.jsoup.nodes.Attribute; -import org.jsoup.nodes.Comment; -import org.jsoup.nodes.DataNode; -import org.jsoup.nodes.Document; -import org.jsoup.nodes.Element; -import org.jsoup.nodes.Html5Printer; -import org.jsoup.nodes.Node; -import org.jsoup.nodes.TextNode; -import org.jsoup.parser.Parser; -import org.jsoup.parser.Tag; - -/** Simple one-off solution for TensorBoard vulcanization. */ -public final class Vulcanize { - - private static final Pattern IGNORE_PATHS_PATTERN = - Pattern.compile("/(?:polymer|marked-element)/.*"); - - private static final ImmutableSet EXTRA_JSDOC_TAGS = - ImmutableSet.of("attribute", "hero", "group", "required"); - - private static final Pattern WEBPATH_PATTERN = Pattern.compile("//~~WEBPATH~~([^\n]+)"); - - private static final Parser parser = Parser.htmlParser(); - private static final Map webfiles = new HashMap<>(); - private static final Set alreadyInlined = new HashSet<>(); - private static final Set legalese = new HashSet<>(); - private static final List licenses = new ArrayList<>(); - private static final List stack = new ArrayList<>(); - private static final List externs = new ArrayList<>(); - private static final List sourcesFromJsLibraries = new ArrayList<>(); - private static final Map sourcesFromScriptTags = new LinkedHashMap<>(); - private static final Map sourceTags = new LinkedHashMap<>(); - private static final Multimap suppressions = HashMultimap.create(); - private static CompilationLevel compilationLevel; - private static Webpath outputPath; - private static Node firstCompiledScript; - private static Node licenseComment; - private static int insideDemoSnippet; - private static boolean testOnly; - - public static void main(String[] args) throws IOException { - 
compilationLevel = CompilationLevel.fromString(args[0]); - testOnly = args[1].equals("true"); - Webpath inputPath = Webpath.get(args[2]); - outputPath = Webpath.get(args[3]); - Path output = Paths.get(args[4]); - for (int i = 5; i < args.length; i++) { - if (args[i].endsWith(".js")) { - String code = new String(Files.readAllBytes(Paths.get(args[i])), UTF_8); - SourceFile sourceFile = SourceFile.fromCode(args[i], code); - if (code.contains("@externs")) { - externs.add(sourceFile); - } else { - sourcesFromJsLibraries.add(sourceFile); - } - continue; - } - if (!args[i].endsWith(".pbtxt")) { - continue; - } - Webfiles manifest = loadWebfilesPbtxt(Paths.get(args[i])); - for (WebfilesSource src : manifest.getSrcList()) { - webfiles.put(Webpath.get(src.getWebpath()), Paths.get(src.getPath())); - } - } - stack.add(inputPath); - Document document = parse(Files.readAllBytes(webfiles.get(inputPath))); - transform(document); - compile(); - if (licenseComment != null) { - licenseComment.attr("comment", String.format("\n%s\n", Joiner.on("\n\n").join(licenses))); - } - Files.write( - output, - Html5Printer.stringify(document).getBytes(UTF_8), - StandardOpenOption.WRITE, - StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING); - } - - private static void transform(Node root) throws IOException { - Node node = checkNotNull(root); - Node newNode; - while (true) { - newNode = enterNode(node); - if (node.equals(root)) { - root = newNode; - } - node = newNode; - if (node.childNodeSize() > 0) { - node = node.childNode(0); - } else { - while (true) { - newNode = leaveNode(node); - if (node.equals(root)) { - root = newNode; - } - node = newNode; - if (node.equals(root)) { - return; - } - Node next = node.nextSibling(); - if (next == null) { - if (node.parentNode() == null) { - return; - } - node = verifyNotNull(node.parentNode(), "unexpected root: %s", node); - } else { - node = next; - break; - } - } - } - } - } - - private static Node enterNode(Node node) throws IOException 
{ - if (node.nodeName().equals("demo-snippet")) { - insideDemoSnippet++; - } - if (insideDemoSnippet > 0) { - return node; - } - if (node instanceof Element) { - if (!getAttrTransitive(node, "vulcanize-noinline").isPresent()) { - if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { - // Inline HTML. - node = visitHtmlImport(node); - } else if (node.nodeName().equals("script") - && !shouldIgnoreUri(node.attr("src")) - && !node.hasAttr("jscomp-ignore")) { - node = visitScript(node); - } else if (node.nodeName().equals("link") - && node.attr("rel").equals("stylesheet") - && !node.attr("href").isEmpty() - && !shouldIgnoreUri(node.attr("href"))) { - node = visitStylesheet(node); - } - } - rootifyAttribute(node, "href"); - rootifyAttribute(node, "src"); - rootifyAttribute(node, "action"); - rootifyAttribute(node, "assetpath"); - } else if (node instanceof Comment) { - String text = ((Comment) node).getData(); - if (text.contains("@license")) { - handleLicense(text); - if (licenseComment == null) { - licenseComment = node; - } else { - node = replaceNode(node, new TextNode("", node.baseUri())); - } - } else { - node = replaceNode(node, new TextNode("", node.baseUri())); - } - } - return node; - } - - private static Node leaveNode(Node node) { - if (node instanceof Document) { - stack.remove(stack.size() - 1); - } else if (node.nodeName().equals("demo-snippet")) { - insideDemoSnippet--; - } - return node; - } - - private static Node visitHtmlImport(Node node) throws IOException { - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - if (alreadyInlined.add(href)) { - stack.add(href); - Document subdocument = parse(Files.readAllBytes(getWebfile(href))); - for (Attribute attr : node.attributes()) { - subdocument.attr(attr.getKey(), attr.getValue()); - } - return replaceNode(node, subdocument); - } else { - return replaceNode(node, new TextNode("", node.baseUri())); - } - } - - private static Node visitScript(Node node) throws IOException { - 
Webpath path; - String script; - if (node.attr("src").isEmpty()) { - path = makeSyntheticName(".js"); - script = getInlineScriptFromNode(node); - } else { - path = me().lookup(Webpath.get(node.attr("src"))); - script = new String(Files.readAllBytes(getWebfile(path)), UTF_8); - } - if (node.attr("src").endsWith(".min.js") - || getAttrTransitive(node, "jscomp-nocompile").isPresent()) { - Node newScript = - new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) - .appendChild(new DataNode(script, node.baseUri())) - .removeAttr("src") - .removeAttr("jscomp-nocompile"); - if (firstCompiledScript != null) { - firstCompiledScript.before(newScript); - return replaceNode(node, new TextNode("", node.baseUri())); - } else { - return replaceNode(node, newScript); - } - } else { - if (firstCompiledScript == null) { - firstCompiledScript = node; - } - sourcesFromScriptTags.put(path, script); - sourceTags.put(path, node); - Optional suppress = getAttrTransitive(node, "jscomp-suppress"); - if (suppress.isPresent()) { - if (suppress.get().isEmpty()) { - suppressions.put(path, "*"); - } else { - suppressions.putAll(path, Splitter.on(' ').split(suppress.get())); - } - } - return node; - } - } - - private static Node visitStylesheet(Node node) throws IOException { - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - return replaceNode( - node, - new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) - .appendChild( - new DataNode( - new String(Files.readAllBytes(getWebfile(href)), UTF_8), node.baseUri())) - .removeAttr("rel") - .removeAttr("href")); - } - - private static Optional getAttrTransitive(Node node, String attr) { - while (node != null) { - if (node.hasAttr(attr)) { - return Optional.of(node.attr(attr)); - } - node = node.parent(); - } - return Optional.absent(); - } - - private static Node replaceNode(Node oldNode, Node newNode) { - oldNode.replaceWith(newNode); - return newNode; - } - - private static Path getWebfile(Webpath path) { 
- return verifyNotNull(webfiles.get(path), "Bad ref: %s -> %s", me(), path); - } - - private static void compile() { - if (sourcesFromScriptTags.isEmpty()) { - return; - } - - CompilerOptions options = new CompilerOptions(); - compilationLevel.setOptionsForCompilationLevel(options); - - // Nice options. - options.setColorizeErrorOutput(true); - options.setContinueAfterErrors(true); - options.setLanguageIn(CompilerOptions.LanguageMode.ECMASCRIPT_2016); - options.setLanguageOut(CompilerOptions.LanguageMode.ECMASCRIPT5); - options.setGenerateExports(true); - options.setStrictModeInput(false); - options.setExtraAnnotationNames(EXTRA_JSDOC_TAGS); - - // So we can chop JS binary back up into the original script tags. - options.setPrintInputDelimiter(true); - options.setInputDelimiter("//~~WEBPATH~~%name%"); - - // Optimizations that are too advanced for us right now. - options.setPropertyRenaming(PropertyRenamingPolicy.OFF); - options.setCheckGlobalThisLevel(CheckLevel.OFF); - options.setRemoveUnusedPrototypeProperties(false); - options.setRemoveUnusedPrototypePropertiesInExterns(false); - options.setRemoveUnusedClassProperties(false); - - // Dependency management. - options.setClosurePass(true); - options.setManageClosureDependencies(true); - options.getDependencyOptions().setDependencyPruning(true); - options.getDependencyOptions().setDependencySorting(true); - options.getDependencyOptions().setMoocherDropping(false); - options.getDependencyOptions() - .setEntryPoints( - sourceTags - .keySet() - .stream() - .map(Webpath::toString) - .map(ModuleIdentifier::forFile) - .collect(Collectors.toList())); - - // Polymer pass. - options.setPolymerVersion(1); - - // Debug flags. 
- if (testOnly) { - options.setPrettyPrint(true); - options.setGeneratePseudoNames(true); - options.setExportTestFunctions(true); - } - - // Don't print warnings from " - sanitized = "<script>alert('xss')</script>" - self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) - - dangerous = textwrap.dedent("""\ - hello *you*""") - sanitized = '

hello you

' - self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) - - def testTableGeneration(self): - array2d = np.array([['one', 'two'], ['three', 'four']]) - expected_table = textwrap.dedent("""\ - - - - - - - - - - - -
onetwo
threefour
""") - self.assertEqual(text_plugin.make_table(array2d), expected_table) - - expected_table_with_headers = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - -
c1c2
onetwo
threefour
""") - - actual_with_headers = text_plugin.make_table(array2d, headers=['c1', 'c2']) - self.assertEqual(actual_with_headers, expected_table_with_headers) - - array_1d = np.array(['one', 'two', 'three', 'four', 'five']) - expected_1d = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - - -
one
two
three
four
five
""") - self.assertEqual(text_plugin.make_table(array_1d), expected_1d) - - expected_1d_with_headers = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - - - - - - - -
X
one
two
three
four
five
""") - actual_1d_with_headers = text_plugin.make_table(array_1d, headers=['X']) - self.assertEqual(actual_1d_with_headers, expected_1d_with_headers) - - def testMakeTableExceptions(self): - # Verify that contents is being type-checked and shape-checked. - with self.assertRaises(ValueError): - text_plugin.make_table([]) - - with self.assertRaises(ValueError): - text_plugin.make_table('foo') - - with self.assertRaises(ValueError): - invalid_shape = np.full((3, 3, 3), 'nope', dtype=np.dtype('S3')) - text_plugin.make_table(invalid_shape) - - # Test headers exceptions in 2d array case. - test_array = np.full((3, 3), 'foo', dtype=np.dtype('S3')) - with self.assertRaises(ValueError): - # Headers is wrong type. - text_plugin.make_table(test_array, headers='foo') - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=['foo', 'bar', 'zod', 'zoink']) - with self.assertRaises(ValueError): - # headers is 2d - text_plugin.make_table(test_array, headers=test_array) - - # Also make sure the column counting logic works in the 1d array case. - test_array = np.array(['foo', 'bar', 'zod']) - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=test_array) - - def test_reduce_to_2d(self): - - def make_range_array(dim): - """Produce an incrementally increasing multidimensional array. - - Args: - dim: the number of dimensions for the array - - Returns: - An array of increasing integer elements, with dim dimensions and size - two in each dimension. - - Example: rangeArray(2) results in [[0,1],[2,3]]. - """ - return np.array(range(2**dim)).reshape([2] * dim) - - for i in range(2, 5): - actual = text_plugin.reduce_to_2d(make_range_array(i)) - expected = make_range_array(2) - np.testing.assert_array_equal(actual, expected) - - def test_text_array_to_html(self): - - convert = text_plugin.text_array_to_html - scalar = np.array('foo') - scalar_expected = '

foo

' - self.assertEqual(convert(scalar), scalar_expected) - - vector = np.array(['foo', 'bar']) - vector_expected = textwrap.dedent("""\ - - - - - - - - - -

foo

bar

""") - self.assertEqual(convert(vector), vector_expected) - - d2 = np.array([['foo', 'bar'], ['zoink', 'zod']]) - d2_expected = textwrap.dedent("""\ - - - - - - - - - - - -

foo

bar

zoink

zod

""") - self.assertEqual(convert(d2), d2_expected) - - d3 = np.array([[['foo', 'bar'], ['zoink', 'zod']], [['FOO', 'BAR'], - ['ZOINK', 'ZOD']]]) - - warning = text_plugin.markdown_and_sanitize(text_plugin.WARNING_TEMPLATE % - 3) - d3_expected = warning + textwrap.dedent("""\ - - - - - - - - - - - -

foo

bar

zoink

zod

""") - self.assertEqual(convert(d3), d3_expected) - - def testPluginIsActive(self): - plugin = text_plugin.TextPlugin() - multiplexer = event_multiplexer.EventMultiplexer() - plugin.get_plugin_apps(event_multiplexer.EventMultiplexer(), None) - - # The plugin is inactive because text summaries are not available. - self.assertFalse(plugin.is_active()) - - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - - # The plugin is active because text summaries are available. - self.assertTrue(self.plugin.is_active()) - - def testUnicode(self): - self.assertConverted(u'

Iñtërnâtiônàlizætiøn⚡💩

', - 'Iñtërnâtiônàlizætiøn⚡💩') - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/scripts/BUILD b/tensorflow/tensorboard/scripts/BUILD deleted file mode 100644 index 05425ee61d05e3a0e540106a8c313205562b347c..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/scripts/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -# Description: -# Some useful scripts that are bundled with TensorBoard. - -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_binary( - name = "generate_testdata", - srcs = ["generate_testdata.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_binary( - name = "execrooter", - srcs = ["execrooter.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["*"]), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/tensorboard/scripts/execrooter.py b/tensorflow/tensorboard/scripts/execrooter.py deleted file mode 100644 index 65569b9151258dc692ec45223a4f9118ea803126..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/scripts/execrooter.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Utility for running programs in a symlinked execroot.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import os -import shutil -import subprocess -import sys -import tempfile - - -def run(inputs, program, outputs): - """Creates temp symlink tree, runs program, and copies back outputs. - - Args: - inputs: List of fake paths to real paths, which are used for symlink tree. - program: List containing real path of program and its arguments. The - execroot directory will be appended as the last argument. - outputs: List of fake outputted paths to copy back to real paths. - Returns: - 0 if succeeded or nonzero if failed. - """ - root = tempfile.mkdtemp() - try: - cwd = os.getcwd() - for fake, real in inputs: - parent = os.path.join(root, os.path.dirname(fake)) - if not os.path.exists(parent): - os.makedirs(parent) - os.symlink(os.path.join(cwd, real), os.path.join(root, fake)) - if subprocess.call(program + [root]) != 0: - return 1 - for fake, real in outputs: - shutil.copyfile(os.path.join(root, fake), real) - return 0 - finally: - shutil.rmtree(root) - - -def main(args): - """Invokes run function using a JSON file config. - - Args: - args: CLI args, which can be a JSON file containing an object whose - attributes are the parameters to the run function. If multiple JSON - files are passed, their contents are concatenated. - Returns: - 0 if succeeded or nonzero if failed. - Raises: - Exception: If input data is missing. 
- """ - if not args: - raise Exception('Please specify at least one JSON config path') - inputs = [] - program = [] - outputs = [] - for arg in args: - with open(arg) as fd: - config = json.load(fd) - inputs.extend(config.get('inputs', [])) - program.extend(config.get('program', [])) - outputs.extend(config.get('outputs', [])) - if not program: - raise Exception('Please specify a program') - return run(inputs, program, outputs) - - -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py deleted file mode 100644 index f191d16a82dc9f771ea4f1d42a510625c157d119..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/scripts/generate_testdata.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Generate some standard test data for debugging TensorBoard. 
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import bisect -import math -import os -import os.path -import random -import shutil - -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - - -tf.flags.DEFINE_string("target", None, """The directory where serialized data -will be written""") - -tf.flags.DEFINE_boolean("overwrite", False, """Whether to remove and overwrite -TARGET if it already exists.""") - -FLAGS = tf.flags.FLAGS - -# Hardcode a start time and reseed so script always generates the same data. -_start_time = 0 -random.seed(0) - - -def _MakeHistogramBuckets(): - v = 1E-12 - buckets = [] - neg_buckets = [] - while v < 1E20: - buckets.append(v) - neg_buckets.append(-v) - v *= 1.1 - # Should include DBL_MAX, but won't bother for test data. - return neg_buckets[::-1] + [0] + buckets - - -def _MakeHistogram(values): - """Convert values into a histogram proto using logic from histogram.cc.""" - limits = _MakeHistogramBuckets() - counts = [0] * len(limits) - for v in values: - idx = bisect.bisect_left(limits, v) - counts[idx] += 1 - - limit_counts = [(limits[i], counts[i]) for i in xrange(len(limits)) - if counts[i]] - bucket_limit = [lc[0] for lc in limit_counts] - bucket = [lc[1] for lc in limit_counts] - sum_sq = sum(v * v for v in values) - return tf.HistogramProto( - min=min(values), - max=max(values), - num=len(values), - sum=sum(values), - sum_squares=sum_sq, - bucket_limit=bucket_limit, - bucket=bucket) - - -def WriteScalarSeries(writer, tag, f, n=5): - """Write a series of scalar events to writer, using f to create values.""" - step = 0 - wall_time = _start_time - for i in xrange(n): - v = f(i) - value = tf.Summary.Value(tag=tag, simple_value=v) - summary = tf.Summary(value=[value]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 1 - wall_time += 10 - - -def 
WriteHistogramSeries(writer, tag, mu_sigma_tuples, n=20): - """Write a sequence of normally distributed histograms to writer.""" - step = 0 - wall_time = _start_time - for [mean, stddev] in mu_sigma_tuples: - data = [random.normalvariate(mean, stddev) for _ in xrange(n)] - histo = _MakeHistogram(data) - summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 10 - wall_time += 100 - - -def WriteImageSeries(writer, tag, n_images=1): - """Write a few dummy images to writer.""" - step = 0 - session = tf.Session() - p = tf.placeholder("uint8", (1, 4, 4, 3)) - s = tf.summary.image(tag, p) - for _ in xrange(n_images): - im = np.random.random_integers(0, 255, (1, 4, 4, 3)) - summ = session.run(s, feed_dict={p: im}) - writer.add_summary(summ, step) - step += 20 - session.close() - - -def WriteAudioSeries(writer, tag, n_audio=1): - """Write a few dummy audio clips to writer.""" - step = 0 - session = tf.Session() - - min_frequency_hz = 440 - max_frequency_hz = 880 - sample_rate = 4000 - duration_frames = sample_rate // 2 # 0.5 seconds. - frequencies_per_run = 1 - num_channels = 2 - - p = tf.placeholder("float32", (frequencies_per_run, duration_frames, - num_channels)) - s = tf.summary.audio(tag, p, sample_rate) - - for _ in xrange(n_audio): - # Generate a different frequency for each channel to show stereo works. 
- frequencies = np.random.random_integers( - min_frequency_hz, - max_frequency_hz, - size=(frequencies_per_run, num_channels)) - tiled_frequencies = np.tile(frequencies, (1, duration_frames)) - tiled_increments = np.tile( - np.arange(0, duration_frames), - (num_channels, 1)).T.reshape(1, duration_frames * num_channels) - tones = np.sin(2.0 * np.pi * tiled_frequencies * tiled_increments / - sample_rate) - tones = tones.reshape(frequencies_per_run, duration_frames, num_channels) - - summ = session.run(s, feed_dict={p: tones}) - writer.add_summary(summ, step) - step += 20 - session.close() - - -def GenerateTestData(path): - """Generates the test data directory.""" - run1_path = os.path.join(path, "run1") - os.makedirs(run1_path) - writer1 = tf.summary.FileWriter(run1_path) - WriteScalarSeries(writer1, "foo/square", lambda x: x * x) - WriteScalarSeries(writer1, "bar/square", lambda x: x * x) - WriteScalarSeries(writer1, "foo/sin", math.sin) - WriteScalarSeries(writer1, "foo/cos", math.cos) - WriteHistogramSeries(writer1, "histo1", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer1, "im1") - WriteImageSeries(writer1, "im2") - WriteAudioSeries(writer1, "au1") - - run2_path = os.path.join(path, "run2") - os.makedirs(run2_path) - writer2 = tf.summary.FileWriter(run2_path) - WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) - WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) - WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) - WriteHistogramSeries(writer2, "histo1", [[0, 2], [0.3, 2], [0.5, 2], [0.7, 2], - [1, 2]]) - WriteHistogramSeries(writer2, "histo2", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer2, "im1") - WriteAudioSeries(writer2, "au2") - - graph_def = tf.GraphDef() - node1 = graph_def.node.add() - node1.name = "a" - node1.op = "matmul" - node2 = graph_def.node.add() - node2.name = "b" - node2.op = "matmul" - node2.input.extend(["a:0"]) - - writer1.add_graph(graph_def) - 
node3 = graph_def.node.add() - node3.name = "c" - node3.op = "matmul" - node3.input.extend(["a:0", "b:0"]) - writer2.add_graph(graph_def) - writer1.close() - writer2.close() - - -def main(unused_argv=None): - target = FLAGS.target - if not target: - print("The --target flag is required.") - return -1 - if os.path.exists(target): - if FLAGS.overwrite: - if os.path.isdir(target): - shutil.rmtree(target) - else: - os.remove(target) - else: - print("Refusing to overwrite target %s without --overwrite" % target) - return -2 - GenerateTestData(target) - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d01342827dc26a80ac0d7f829c4e093afcf76abb..d31b9a6e3141b1a30055e3c7cc94c8aec9633675 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -72,6 +72,13 @@ def if_android_arm64(a): }) +def if_android_mips(a): + return select({ + clean_dep("//tensorflow:android_mips"): a, + "//conditions:default": [], + }) + + def if_not_android(a): return select({ clean_dep("//tensorflow:android"): [], @@ -79,6 +86,14 @@ def if_not_android(a): }) +def if_not_android_mips_and_mips64(a): + return select({ + clean_dep("//tensorflow:android_mips"): [], + clean_dep("//tensorflow:android_mips64"): [], + "//conditions:default": a, + }) + + def if_android(a): return select({ clean_dep("//tensorflow:android"): a, @@ -117,11 +132,9 @@ def if_not_windows(a): }) -def if_x86(a): +def if_linux_x86_64(a): return select({ clean_dep("//tensorflow:linux_x86_64"): a, - clean_dep("//tensorflow:windows"): a, - clean_dep("//tensorflow:windows_msvc"): a, "//conditions:default": [], }) @@ -138,17 +151,21 @@ WIN_COPTS = [ "/DTF_COMPILE_LIBRARY", "/DEIGEN_HAS_C99_MATH", "/DTENSORFLOW_USE_EIGEN_THREADPOOL", + "/DEIGEN_AVOID_STL_ARRAY", + "/Iexternal/gemmlowp", + "/wd4018", # -Wno-sign-compare + "/U_HAS_EXCEPTIONS", "/D_HAS_EXCEPTIONS=1", "/EHsc", # -fno-exceptions ] # LINT.IfChange def tf_copts(): - return ([ + return 
(if_not_windows([ "-DEIGEN_AVOID_STL_ARRAY", "-Iexternal/gemmlowp", "-Wno-sign-compare", "-fno-exceptions", - ] + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( - ["-mfpu=neon"]) + if_x86(["-msse3"]) + select({ + ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( + ["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({ clean_dep("//tensorflow:android"): [ "-std=c++11", "-DTF_LEAN_BINARY", @@ -167,7 +184,7 @@ def tf_opts_nortti_if_android(): "-fno-rtti", "-DGOOGLE_PROTOBUF_NO_RTTI", "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER", - ]) + if_android_x86(["-msse4.1"]) + ]) # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) @@ -851,7 +868,7 @@ def cc_header_only_library(name, deps=[], **kwargs): def tf_custom_op_library_additional_deps(): return [ - "@protobuf//:protobuf_headers", + "@protobuf_archive//:protobuf_headers", clean_dep("//third_party/eigen3"), clean_dep("//tensorflow/core:framework_headers_lib"), ] @@ -1021,9 +1038,9 @@ def tf_py_wrap_cc(name, native.cc_binary( name=cc_library_name, srcs=[module_name + ".cc"], - copts=(copts + [ + copts=(copts + if_not_windows([ "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings" - ] + tf_extension_copts()), + ]) + tf_extension_copts()), linkopts=tf_extension_linkopts() + extra_linkopts, linkstatic=1, linkshared=1, diff --git a/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt index 72cc53244768ad515c0ce33b937a2eae3a9fd98a..a095616c00cfe8fb64413e2078ae1589a423d2f4 100644 --- a/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt index 5c77b3dd5cca6c7741764e6b4bcea82ef30a47fd..260c796fd65b90020eb2b8191645ffdb2402a4a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt @@ -13,7 +13,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\', \'encoding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "num_records_produced" diff --git a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt index f5b0bae58d0d11d1fb0b83e3996a038f6254ccdc..0a3b81bf829f48e88e9c48ce26cdbb4207101a16 100644 --- a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt @@ -34,7 +34,7 @@ tf_class { } member_method { name: "make_callable" - argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'fetches\', \'feed_list\', \'accept_options\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "partial_run" diff --git a/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt new file mode 100644 index 
0000000000000000000000000000000000000000..f9b7e9bbca82858ca99e67d70cf93583ca75972f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.LMDBReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt index 1bfe723ce754830efeebd7644871ff29f9809423..8fed133561544b91abfc64577e63a7088b43a007 100644 --- a/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" 
argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt index dbe25f3a5b9ecc1596c77862396c684b6ddb9c5f..ebb017e81bc29e062d804fbe9f50c62f7b615dab 100644 --- a/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt b/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt index 9263d73a51161e9df083992528400b57302832d2..761f90989f316611d42580ee911e24bb3d0d2fec 100644 --- a/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt @@ -54,6 +54,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt index 
ec783ffe5a01d66965d6370ec1bc6c83178b5a8c..f3ca84139311bc05478e3dce876b53f7b9dec883 100644 --- a/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt index 173cd1963e5e8c088556e8530b65ac1bdee99dc3..1d6b037f9c3540653a8fb18b6508f74b01da66ab 100644 --- a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt @@ -34,7 +34,7 @@ tf_class { } member_method { name: "make_callable" - argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'fetches\', \'feed_list\', \'accept_options\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "partial_run" diff --git a/tensorflow/tools/api/golden/tensorflow.-auto-parallel-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt similarity index 87% rename from tensorflow/tools/api/golden/tensorflow.-auto-parallel-options.pbtxt rename to tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt index c8f3e8fb154c5a1a2bb61759d9241d7e79fe884e..067f02ce8cbb1a1f6e65758f37bb1d36927fad98 100644 --- a/tensorflow/tools/api/golden/tensorflow.-auto-parallel-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt @@ -1,21 +1,21 @@ -path: "tensorflow.AutoParallelOptions" +path: 
"tensorflow.SummaryMetadata.PluginData" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { - name: "DESCRIPTOR" - mtype: "" + name: "CONTENT_FIELD_NUMBER" + mtype: "" } member { - name: "ENABLE_FIELD_NUMBER" - mtype: "" + name: "DESCRIPTOR" + mtype: "" } member { name: "Extensions" mtype: "" } member { - name: "NUM_REPLICAS_FIELD_NUMBER" + name: "PLUGIN_NAME_FIELD_NUMBER" mtype: "" } member_method { diff --git a/tensorflow/tools/api/golden/tensorflow.-rewriter-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt similarity index 65% rename from tensorflow/tools/api/golden/tensorflow.-rewriter-config.pbtxt rename to tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt index 34d2e1761280de8079f82bef02b7dc2cc5ace442..b9156521ccbee25486113a82ddec1053f8b32e3b 100644 --- a/tensorflow/tools/api/golden/tensorflow.-rewriter-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt @@ -1,21 +1,13 @@ -path: "tensorflow.RewriterConfig" +path: "tensorflow.SummaryMetadata" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - member { - name: "AUTO_PARALLEL_FIELD_NUMBER" - mtype: "" - } - member { - name: "CONSTANT_FOLDING_FIELD_NUMBER" - mtype: "" - } member { name: "DESCRIPTOR" mtype: "" } member { - name: "DISABLE_MODEL_PRUNING_FIELD_NUMBER" + name: "DISPLAY_NAME_FIELD_NUMBER" mtype: "" } member { @@ -23,27 +15,15 @@ tf_class { mtype: "" } member { - name: "MANUAL" + name: "PLUGIN_DATA_FIELD_NUMBER" mtype: "" } member { - name: "MEMORY_OPTIMIZATION_FIELD_NUMBER" - mtype: "" - } - member { - name: "MemOptType" - mtype: "" - } - member { - name: "NO_MEM_OPT" - mtype: "" - } - member { - name: "OPTIMIZERS_FIELD_NUMBER" - mtype: "" + name: "PluginData" + mtype: "" } member { - name: "OPTIMIZE_TENSOR_LAYOUT_FIELD_NUMBER" + name: "SUMMARY_DESCRIPTION_FIELD_NUMBER" mtype: "" } member_method { diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt index d5b9cb8f5ed3cf088f5bd27809ff98f00801217d..8e3598fb2470b327e6e3601969f055d4907f614a 100644 --- a/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt @@ -54,6 +54,10 @@ tf_class { name: "merge_with" argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "most_specific_compatible_shape" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "num_elements" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt b/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1e4d333cc0bb0bb33fb4cc8d76badd30c8babaa4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.bitwise" +tf_module { + member_method { + name: "bitwise_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_xor" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-bernoulli.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-bernoulli.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cfe09345acccc410ad3041a965901134440e3c77 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-bernoulli.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.distributions.Bernoulli" 
+tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "logits" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "probs" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'logits\', \'probs\', \'dtype\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'False\', \'True\', \'Bernoulli\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + 
name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-beta.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.distributions.-beta.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2e6578bae1604f69e4697bb4668dd69d94bd68b5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-beta.pbtxt @@ -0,0 +1,139 @@ +path: "tensorflow.distributions.Beta" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration0" + mtype: "" + } + member { + name: "concentration1" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_concentration" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'concentration1\', \'concentration0\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'True\', \'Beta\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], 
varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + 
} + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-categorical.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-categorical.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d42b0e82e4fab3e30d3ebf1b8bea8b44bb61ea0f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-categorical.pbtxt @@ -0,0 +1,139 @@ +path: "tensorflow.distributions.Categorical" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "event_size" + mtype: "" + } + member { + name: "logits" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "probs" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'logits\', \'probs\', \'dtype\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'False\', \'True\', \'Categorical\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, 
keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + 
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet-multinomial.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet-multinomial.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..710164743e851f0bb5c31ebe78b260b623e87378 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet-multinomial.pbtxt @@ -0,0 +1,139 @@ +path: "tensorflow.distributions.DirichletMultinomial" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_concentration" + mtype: "" + } + member { + name: "total_count" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'total_count\', \'concentration\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, 
keywords=None, defaults=[\'False\', \'True\', \'DirichletMultinomial\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + 
argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6cc361672ed8da313e1bebc41fbf093e019d38ad --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.distributions.Dirichlet" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: 
"" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_concentration" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'concentration\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Dirichlet\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, 
defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-distribution.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-distribution.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..40ad07d1be4bdea9585eb276debb1fdf3dfff583 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-distribution.pbtxt @@ -0,0 +1,126 @@ +path: "tensorflow.distributions.Distribution" +tf_class { + is_instance: "" + is_instance: "" + 
is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'reparameterization_type\', \'validate_args\', \'allow_nan_stats\', \'parameters\', \'graph_parents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + 
name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-exponential.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-exponential.pbtxt new file mode 100644 index 
0000000000000000000000000000000000000000..8f34d25fea873827997ecd9df10cf1b3bfd0e56b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-exponential.pbtxt @@ -0,0 +1,136 @@ +path: "tensorflow.distributions.Exponential" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "rate" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Exponential\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, 
keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } 
+ member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-gamma.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-gamma.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0ae88fba3b4fd176641cc17c916181cc9a6a12c6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-gamma.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.distributions.Gamma" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "rate" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'concentration\', \'rate\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Gamma\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, 
keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } 
+ member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-laplace.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-laplace.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e7cd595e946cb91f162a2a1af8753e44cdfbc0e1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-laplace.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.distributions.Laplace" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "loc" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "scale" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loc\', \'scale\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Laplace\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, 
keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + 
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-multinomial.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-multinomial.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7a4a16ff836a485e65cb6e061e27b92907cb4a63 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-multinomial.pbtxt @@ -0,0 +1,139 @@ +path: "tensorflow.distributions.Multinomial" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "logits" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "probs" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_count" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'total_count\', \'logits\', \'probs\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'True\', 
\'Multinomial\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], 
varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..14c8c34cc2d8efacec706bdb894d9f069d5e7033 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-normal.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.distributions.Normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "loc" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + 
member { + name: "scale" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loc\', \'scale\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Normal\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], 
varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-register-k-l.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-register-k-l.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e3db443c2bdaa70f7651126a30caf2062a3c6f67 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-register-k-l.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distributions.RegisterKL" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dist_cls_a\', \'dist_cls_b\'], varargs=None, keywords=None, 
defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-reparameterization-type.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-reparameterization-type.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..02e8d576ddd00aa21005fa39cd323a92392bf75a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-reparameterization-type.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distributions.ReparameterizationType" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rep_type\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-student-t.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-student-t.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..30db6d3f35c1c8ea7bbc376a20093302dd373bd9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-student-t.pbtxt @@ -0,0 +1,139 @@ +path: "tensorflow.distributions.StudentT" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "df" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "loc" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "scale" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'df\', \'loc\', \'scale\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'StudentT\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', 
\'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: 
"param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-uniform.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.-uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..46cbdf225f68e879fd18ef4a07048746a9a71b08 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.-uniform.pbtxt @@ -0,0 +1,139 @@ +path: "tensorflow.distributions.Uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "high" + mtype: "" + } + member { + name: "low" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + 
member_method { + name: "__init__" + argspec: "args=[\'self\', \'low\', \'high\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'False\', \'True\', \'Uniform\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + 
name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "range" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..11565bd3e4178202fa82e2e079d1035190dbd6ec --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt @@ -0,0 +1,65 @@ +path: "tensorflow.distributions.bijectors.Bijector" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + 
member { + name: "event_ndims" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_constant_jacobian" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'event_ndims\', \'graph_parents\', \'is_constant_jacobian\', \'validate_args\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "forward" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward\'], " + } + member_method { + name: "forward_event_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "forward_event_shape_tensor" + argspec: "args=[\'self\', \'input_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_event_shape_tensor\'], " + } + member_method { + name: "forward_log_det_jacobian" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_log_det_jacobian\'], " + } + member_method { + name: "inverse" + argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], " + } + member_method { + name: "inverse_event_shape" + argspec: "args=[\'self\', \'output_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "inverse_event_shape_tensor" + argspec: "args=[\'self\', \'output_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_event_shape_tensor\'], " + } + member_method { + name: "inverse_log_det_jacobian" + argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_log_det_jacobian\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt 
new file mode 100644 index 0000000000000000000000000000000000000000..1e5fe624eb838e188594d03b656c12890db344a1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt @@ -0,0 +1,66 @@ +path: "tensorflow.distributions.bijectors.Identity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member { + name: "event_ndims" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_constant_jacobian" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'validate_args\', \'event_ndims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'0\', \'identity\'], " + } + member_method { + name: "forward" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward\'], " + } + member_method { + name: "forward_event_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "forward_event_shape_tensor" + argspec: "args=[\'self\', \'input_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_event_shape_tensor\'], " + } + member_method { + name: "forward_log_det_jacobian" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_log_det_jacobian\'], " + } + member_method { + name: "inverse" + argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], " + } + member_method { + name: "inverse_event_shape" + argspec: "args=[\'self\', \'output_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "inverse_event_shape_tensor" + argspec: "args=[\'self\', \'output_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_event_shape_tensor\'], " + } + member_method { + name: 
"inverse_log_det_jacobian" + argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_log_det_jacobian\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1d0144f36ec332740889dc8caa5add8f41960d92 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.distributions.bijectors" +tf_module { + member { + name: "Bijector" + mtype: "" + } + member { + name: "Identity" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2fba7c506ed9d2490e7c19c1746d3f4e9645424f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt @@ -0,0 +1,79 @@ +path: "tensorflow.distributions" +tf_module { + member { + name: "Bernoulli" + mtype: "" + } + member { + name: "Beta" + mtype: "" + } + member { + name: "Categorical" + mtype: "" + } + member { + name: "Dirichlet" + mtype: "" + } + member { + name: "DirichletMultinomial" + mtype: "" + } + member { + name: "Distribution" + mtype: "" + } + member { + name: "Exponential" + mtype: "" + } + member { + name: "FULLY_REPARAMETERIZED" + mtype: "" + } + member { + name: "Gamma" + mtype: "" + } + member { + name: "Laplace" + mtype: "" + } + member { + name: "Multinomial" + mtype: "" + } + member { + name: "NOT_REPARAMETERIZED" + mtype: "" + } + member { + name: "Normal" + mtype: "" + } + member { + name: "RegisterKL" + mtype: "" + } + member { + name: "ReparameterizationType" + mtype: "" + } + member { + name: "StudentT" + mtype: "" + } + member { + name: "Uniform" + mtype: "" + } + member { + name: "bijectors" + mtype: "" + } + member_method { + name: "kl_divergence" + argspec: 
"args=[\'distribution_a\', \'distribution_b\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3a6f770153013dc925dc1b65a38ec59202c4b0b2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, 
defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..83e53d3960477b8170664c03ee30f588f87454b9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNLinearCombinedClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', 
\'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..17f30a04fbfe7ffe464e7d107f8a9d9a27140188 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNLinearCombinedRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + 
argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..edd68f0bb9ac8654dbc53e090d812de37a168515 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', 
\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt index 5dbfe2172640916803204a4c8f2c5e250bc982d7..6608d21d44c219acbf0265bee368a5a007eebc92 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "eval_metric_ops" mtype: "" } + member { + name: "evaluation_hooks" + mtype: "" + } member { name: "export_outputs" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3715dd5ec76284004efb24b0b6316d1eec87a589 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.LinearClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', 
\'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ccb4abf675f3c05a14990a5ae0da3068fc0d8a47 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.LinearRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', 
\'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index d69c475a313075a5b165dba9a80e30cf8212657d..801260c4507803345c4c84852fd83832b752ac12 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -22,6 +22,10 @@ tf_class { name: "keep_checkpoint_max" mtype: "" } + member { + name: "log_step_count_steps" + mtype: "" + } member { name: "master" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt index 0d5dc73271dbc972c9177a6274f1632862f93ef0..07b04810b5c6d2eda3c3dce5ad4c35592158b085 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt @@ -1,5 +1,21 @@ path: "tensorflow.estimator" tf_module { + member { + name: "DNNClassifier" + mtype: "" + } + member { + name: "DNNLinearCombinedClassifier" + mtype: "" + } + member { + name: "DNNLinearCombinedRegressor" + mtype: "" + } + member { + name: "DNNRegressor" + mtype: "" + } member { name: "Estimator" mtype: "" @@ -8,6 +24,14 @@ tf_module { name: "EstimatorSpec" mtype: "" } + member { + name: "LinearClassifier" + mtype: "" + } + member { + name: "LinearRegressor" + mtype: "" + } member { name: "ModeKeys" mtype: "" @@ -24,4 +48,12 @@ tf_module { name: "inputs" mtype: "" } + member_method { + name: "classifier_parse_example_spec" + argspec: "args=[\'feature_columns\', \'label_key\', \'label_dtype\', \'label_default\', \'weight_column\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\'], " + } + member_method { + name: 
"regressor_parse_example_spec" + argspec: "args=[\'feature_columns\', \'label_key\', \'label_dtype\', \'label_default\', \'label_dimension\', \'weight_column\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'1\', \'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt index 4c633a850f8e069135f122292bac019e2646aa61..2a57a845cdcb92d2c3e5d87e06d4e03886696be1 100644 --- a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt @@ -18,7 +18,7 @@ tf_module { } member_method { name: "categorical_column_with_vocabulary_list" - argspec: "args=[\'key\', \'vocabulary_list\', \'dtype\', \'default_value\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + argspec: "args=[\'key\', \'vocabulary_list\', \'dtype\', \'default_value\', \'num_oov_buckets\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\', \'0\'], " } member_method { name: "crossed_column" diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index 93257c84a1f4ecd078923a2434d4ce48355e13ab..8f7790f2996d795ab7681c93d32909e01250725c 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -96,10 +96,6 @@ tf_module { name: "non_max_suppression" argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } - member_method { - name: "non_max_suppression_v2" - argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "pad_to_bounding_box" argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" diff --git 
a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt index 9f817beafd9251f1cd2a5d7a59f286d302948dc4..3beb95d25c15996b5ceb9c5005373498614bf944 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -256,6 +256,10 @@ tf_module { name: "sampled_softmax_loss" argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], " } + member_method { + name: "selu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "separable_conv2d" argspec: "args=[\'input\', \'depthwise_filter\', \'pointwise_filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index b21d9a8ee3378c2a94636c69b8cbf089e8f04cad..a75e9e808025025b20b9c109e4b040c3b8f97fb7 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -54,7 +54,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'cell\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'cell\', \'residual_fn\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 91abff6e13097faf3e24b85c7bd4ab8a02a303a8..314449bb7353fcf503973aa8847ac2a5c086304b 100644 --- 
a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -8,10 +8,6 @@ tf_module { name: "AttrValue" mtype: "" } - member { - name: "AutoParallelOptions" - mtype: "" - } member { name: "COMPILER_VERSION" mtype: "" @@ -112,6 +108,10 @@ tf_module { name: "InteractiveSession" mtype: "" } + member { + name: "LMDBReader" + mtype: "" + } member { name: "LogMessage" mtype: "" @@ -168,10 +168,6 @@ tf_module { name: "RegisterGradient" mtype: "" } - member { - name: "RewriterConfig" - mtype: "" - } member { name: "RunMetadata" mtype: "" @@ -208,6 +204,10 @@ tf_module { name: "Summary" mtype: "" } + member { + name: "SummaryMetadata" + mtype: "" + } member { name: "TFRecordReader" mtype: "" @@ -260,6 +260,10 @@ tf_module { name: "bfloat16" mtype: "" } + member { + name: "bitwise" + mtype: "" + } member { name: "bool" mtype: "" @@ -284,6 +288,10 @@ tf_module { name: "contrib" mtype: "" } + member { + name: "distributions" + mtype: "" + } member { name: "double" mtype: "" @@ -380,6 +388,10 @@ tf_module { name: "orthogonal_initializer" mtype: "" } + member { + name: "profiler" + mtype: "" + } member { name: "python_io" mtype: "" @@ -476,6 +488,14 @@ tf_module { name: "user_ops" mtype: "" } + member { + name: "variance_scaling_initializer" + mtype: "" + } + member { + name: "variant" + mtype: "" + } member { name: "zeros_initializer" mtype: "" @@ -508,6 +528,10 @@ tf_module { name: "acos" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "add" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -530,19 +554,19 @@ tf_module { } member_method { name: "arg_max" - argspec: "args=[\'input\', \'dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'dimension\', 
\'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "arg_min" - argspec: "args=[\'input\', \'dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "argmax" - argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " } member_method { name: "argmin" - argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " } member_method { name: "as_dtype" @@ -556,6 +580,10 @@ tf_module { name: "asin" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "assert_equal" argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " @@ -648,6 +676,10 @@ tf_module { name: "atan2" argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "batch_to_space" argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -720,6 +752,10 @@ tf_module { name: "clip_by_value" 
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "colocate_with" + argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], " + } member_method { name: "complex" argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -974,7 +1010,7 @@ tf_module { } member_method { name: "gather" - argspec: "args=[\'params\', \'indices\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'params\', \'indices\', \'validate_indices\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0\'], " } member_method { name: "gather_nd" @@ -1032,6 +1068,14 @@ tf_module { name: "global_variables_initializer" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "glorot_normal_initializer" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "glorot_uniform_initializer" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } member_method { name: "gradients" argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\'], " @@ -1346,7 +1390,7 @@ tf_module { } member_method { name: "pad" - argspec: "args=[\'tensor\', \'paddings\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], " + argspec: "args=[\'tensor\', \'paddings\', \'mode\', \'name\', \'constant_values\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\', \'0\'], " } member_method { name: "parallel_stack" @@ -1704,6 +1748,14 @@ tf_module { name: "sparse_placeholder" argspec: 
"args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "sparse_reduce_max" + argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "sparse_reduce_max_sparse" + argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } member_method { name: "sparse_reduce_sum" argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " @@ -1740,6 +1792,10 @@ tf_module { name: "sparse_segment_sum" argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "sparse_slice" + argspec: "args=[\'sp_input\', \'start\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "sparse_softmax" argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bd5c36f390add9cfb31642b80a792d65d59bb3e8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.profiler.AdviceProto.Checker" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "REPORTS_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + 
member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7c8c68e155c99da4f0c1c1ba2c944719c42c12c7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.profiler.AdviceProto.CheckersEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: 
"IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1b789f4fc92ed63fc72f3ecfe6be80a99eb3427f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.profiler.AdviceProto" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CHECKERS_FIELD_NUMBER" + mtype: "" + } + member { + name: "Checker" + mtype: "" + } + member { + name: "CheckersEntry" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + 
name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f0b9605bee1c7cf2f0154f65c475aac49c411f76 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.profiler.GraphNodeProto.InputShapesEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f3bb71354e52fa79516696e6d5a58efeb2a46c18 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt @@ -0,0 +1,164 @@ +path: "tensorflow.profiler.GraphNodeProto" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "CHILDREN_FIELD_NUMBER" + mtype: "" + } + member { + name: "CPU_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "DEVICES_FIELD_NUMBER" + mtype: "" + } + member { + name: "EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FLOAT_OPS_FIELD_NUMBER" + mtype: "" + } + member { + name: "INPUT_SHAPES_FIELD_NUMBER" + mtype: "" + } + member { + name: "InputShapesEntry" + mtype: "" + } + member { + name: "NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "PARAMETERS_FIELD_NUMBER" + mtype: "" + } + member { + name: "REQUESTED_BYTES_FIELD_NUMBER" + mtype: "" + } + member { + name: "RUN_COUNT_FIELD_NUMBER" + mtype: "" + } + member { + name: "SHAPES_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSOR_VALUE_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_CPU_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_DEFINITION_COUNT_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_FLOAT_OPS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_PARAMETERS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_REQUESTED_BYTES_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_RUN_COUNT_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method 
{ + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9b88a11b2c3aabbb6f1e2dc401627cb49eeff7e4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt @@ -0,0 +1,136 @@ +path: "tensorflow.profiler.MultiGraphNodeProto" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "CHILDREN_FIELD_NUMBER" + mtype: "" + } + member { + name: "CPU_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FLOAT_OPS_FIELD_NUMBER" + mtype: "" + } + member { + name: "GRAPH_NODES_FIELD_NUMBER" + mtype: "" + } + member { + name: "NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "PARAMETERS_FIELD_NUMBER" + mtype: "" + } + member { + name: 
"REQUESTED_BYTES_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_ACCELERATOR_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_CPU_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_EXEC_MICROS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_FLOAT_OPS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_PARAMETERS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TOTAL_REQUESTED_BYTES_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c5f9c78c9e85ac4265125790b3f8b29fd0fc6b12 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.profiler.OpLogProto" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + 
name: "LOG_ENTRIES_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-profile-option-builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-profile-option-builder.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..347187a890208eb5b78bb0d1a7040efbdeb3bd3f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-profile-option-builder.pbtxt @@ -0,0 +1,89 @@ +path: "tensorflow.profiler.ProfileOptionBuilder" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "account_displayed_op_only" + argspec: "args=[\'self\', \'is_true\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "float_operation" + argspec: "args=[], varargs=None, keywords=None, 
defaults=None" + } + member_method { + name: "order_by" + argspec: "args=[\'self\', \'attribute\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "select" + argspec: "args=[\'self\', \'attributes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "time_and_memory" + argspec: "args=[\'min_micros\', \'min_bytes\'], varargs=None, keywords=None, defaults=[\'1\', \'1\'], " + } + member_method { + name: "trainable_variables_parameter" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_accounted_types" + argspec: "args=[\'self\', \'account_type_regexes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_empty_output" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_file_output" + argspec: "args=[\'self\', \'outfile\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_max_depth" + argspec: "args=[\'self\', \'max_depth\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_execution_time" + argspec: "args=[\'self\', \'min_micros\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_float_operations" + argspec: "args=[\'self\', \'min_float_ops\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_memory" + argspec: "args=[\'self\', \'min_bytes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_occurrence" + argspec: "args=[\'self\', \'min_occurrence\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_parameters" + argspec: "args=[\'self\', \'min_params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_node_names" + argspec: "args=[\'self\', \'start_name_regexes\', \'show_name_regexes\', \'hide_name_regexes\', \'trim_name_regexes\'], varargs=None, 
keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "with_stdout_output" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_step" + argspec: "args=[\'self\', \'step\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_timeline_output" + argspec: "args=[\'self\', \'timeline_file\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0fb363aca48031e13487d716a0375973f93b3dc8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.profiler.Profiler" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph\', \'op_log\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_step" + argspec: "args=[\'self\', \'step\', \'run_meta\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "advise" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_graph" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_name_scope" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_operations" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_python" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt new file mode 
100644 index 0000000000000000000000000000000000000000..26b25ee3d47241dbf351018f2aacbda12ff33492 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.profiler" +tf_module { + member { + name: "AdviceProto" + mtype: "" + } + member { + name: "GraphNodeProto" + mtype: "" + } + member { + name: "MultiGraphNodeProto" + mtype: "" + } + member { + name: "OpLogProto" + mtype: "" + } + member { + name: "ProfileOptionBuilder" + mtype: "" + } + member { + name: "Profiler" + mtype: "" + } + member_method { + name: "advise" + argspec: "args=[\'graph\', \'run_meta\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " + } + member_method { + name: "profile" + argspec: "args=[\'graph\', \'run_meta\', \'op_log\', \'cmd\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'scope\', \'0\'], " + } + member_method { + name: "write_op_log" + argspec: "args=[\'graph\', \'log_dir\', \'op_log\', \'run_meta\', \'add_trace\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt index af0c11ca14d4f38547a49ac511ee13e15847eb33..31775de2d12bcd2f214f5a04be7a92f49c594fde 100644 --- a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "close" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "write" argspec: "args=[\'self\', \'record\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt index 7c24b7ad3cf38cfd949959d078e5d70838d0b2d9..35e49ee9f4a6ee5b4da2b034ece1c1e3b2136254 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.saved_model.tag_constants" tf_module { + member { + name: "GPU" + mtype: "" + } member { name: "SERVING" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt index bc150e56a36ca22479cdd6a0563466ef6275e143..d95c94668250e1de236462ccdcb134245eebf092 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt @@ -4,4 +4,8 @@ tf_module { name: "build_tensor_info" argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_tensor_from_tensor_info" + argspec: "args=[\'tensor_info\', \'graph\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt index 19d822e61bffb3621b966147519c90d425521e87..326e077d396bc5e3463bba3818f4757127ee0370 100644 --- a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt @@ -58,7 +58,7 @@ tf_module { } member_method { name: "tensor_summary" - argspec: "args=[\'name\', \'tensor\', \'summary_description\', \'collections\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'name\', \'tensor\', \'summary_description\', \'collections\', \'summary_metadata\', \'family\', \'display_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "text" diff --git 
a/tensorflow/tools/api/golden/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.pbtxt index 2a88f26ed02c7e2690c37180f76b965d7ffa87e0..6237207821ab18c8eb3e6148875e29e2e2fad773 100644 --- a/tensorflow/tools/api/golden/tensorflow.test.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.test.pbtxt @@ -30,7 +30,7 @@ tf_module { } member_method { name: "create_local_cluster" - argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\'], varargs=None, keywords=None, defaults=[\'grpc\'], " + argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\', \'worker_config\', \'ps_config\'], varargs=None, keywords=None, defaults=[\'grpc\', \'None\', \'None\'], " } member_method { name: "get_temp_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt index 2dc11df57b60b15a797b1866743b27ea1068624e..5cff6087ef533f6674d6d7f1e0a8be425c16f2ad 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], " } member_method { name: "apply_gradients" diff --git 
a/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt index 62bfdab40bb83c634e101388ecb69da1233c60f9..7caf837cc385dbd64611a58de2c25d4de221a911 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt @@ -9,7 +9,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hooks\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'stop_grace_period_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'120\'], " + argspec: "args=[\'self\', \'hooks\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'stop_grace_period_secs\', \'checkpoint_filename_with_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'120\', \'None\'], " } member_method { name: "close" diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 58fd5760c11d29f063c0f7f66ea0a11d39a08a1e..89c299ae994bcd4f6ceb6daa632f985247d3db7f 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -230,7 +230,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'100\', \'None\', \'None\', \'120\', \'100\'], " + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', 
\'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'\', \'\', \'None\', \'120\', \'100\'], " } member_method { name: "NewCheckpointReader" @@ -304,6 +304,10 @@ tf_module { name: "import_meta_graph" argspec: "args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], " } + member_method { + name: "init_from_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "input_producer" argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " @@ -320,6 +324,18 @@ tf_module { name: "limit_epochs" argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "list_variables" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_variable" + argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "match_filenames_once" argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a58398d645e8397dc8e61a6e0241710c3e34218f --- /dev/null +++ 
b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.variance_scaling_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD index cdfa0e7be524e3bb4ec039ac19bea72747afb58c..2d3b838957d60ffb5e827c6b43100d217cc5739e 100644 --- a/tensorflow/tools/api/lib/BUILD +++ b/tensorflow/tools/api/lib/BUILD @@ -22,7 +22,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":api_objects_proto_py", - "//tensorflow/tools/common:traverse", + "//tensorflow/python:platform", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8421d8fce28611f6049847f6fbca5538475b59af..e9aeeb385586e3abd129d9a475d89545efaca45b 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -23,11 +23,12 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/tools/api/lib:python_object_to_proto_visitor", "//tensorflow/tools/common:public_api", "//tensorflow/tools/common:traverse", - "@protobuf//:protobuf_python", ], ) diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index dfad11adf0b971748cbc64f9b86fd6cb2c7cdd37..38c1bd3fb59da57984052f504684ce4102ee76d9 100644 --- 
a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/algorithm.h" @@ -137,8 +138,21 @@ Status GetOutputShapes(const std::vector& inputs, std::vector > input_tensors; CreateTensorsFromInputInfo(inputs, &input_tensors); std::vector output_tensors; - std::vector output_tensor_names(wanted_shapes.begin(), - wanted_shapes.end()); + std::vector output_tensor_names; + for (const string& wanted_shape : wanted_shapes) { + bool is_input = false; + for (const std::pair& input_tensor : + input_tensors) { + if (input_tensor.first == wanted_shape) { + (*node_shapes)[wanted_shape] = input_tensor.second.shape(); + is_input = true; + break; + } + } + if (!is_input) { + output_tensor_names.push_back(wanted_shape); + } + } TF_RETURN_IF_ERROR( session->Run(input_tensors, output_tensor_names, {}, &output_tensors)); CHECK_EQ(output_tensors.size(), output_tensor_names.size()); @@ -155,7 +169,8 @@ Status CalculateFlops(const GraphDef& graph, Session* session, int64* total_flops, std::unordered_map* flops_by_op) { std::unordered_set floppable_ops = { - "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul"}; + "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul", + "DepthwiseConv2dNative"}; std::set wanted_shapes; for (const NodeDef& node : graph.node()) { @@ -200,6 +215,13 @@ Status CalculateFlops(const GraphDef& graph, } int64 output_count = output_shape.num_elements(); current_flops = k * output_count * 2; + } else if (node.op() == "DepthwiseConv2dNative") { + const TensorShape& filter_shape = found_shapes[node.input(1)]; + const TensorShape& output_shape = found_shapes[node.name()]; + int64 filter_height = filter_shape.dim_size(0); + int64 
filter_width = filter_shape.dim_size(1); + int64 output_count = output_shape.num_elements(); + current_flops = output_count * filter_height * filter_width * 2; } (*flops_by_op)[node.op()] += current_flops; *total_flops += current_flops; diff --git a/tensorflow/tools/ci_build/Dockerfile.tensorboard b/tensorflow/tools/ci_build/Dockerfile.tensorboard deleted file mode 100644 index 9795872e2c4907908c288f8901d0a007f8d1dcaa..0000000000000000000000000000000000000000 --- a/tensorflow/tools/ci_build/Dockerfile.tensorboard +++ /dev/null @@ -1,11 +0,0 @@ -FROM ubuntu:14.04 - -MAINTAINER Jan Prach - -# Copy and run the install scripts. -COPY install/*.sh /install/ -RUN /install/install_bootstrap_deb_packages.sh -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:george-edison55/cmake-3.x -RUN /install/install_deb_packages.sh -RUN /install/install_tensorboard_packages.sh diff --git a/tensorflow/tools/ci_build/builds/builds_common.sh b/tensorflow/tools/ci_build/builds/builds_common.sh index fd9a14bd698d183f14a65079d043e839319a435c..e3b58d038a713b0f8171b3e1803e4329cceda7c9 100644 --- a/tensorflow/tools/ci_build/builds/builds_common.sh +++ b/tensorflow/tools/ci_build/builds/builds_common.sh @@ -16,6 +16,10 @@ # # Common Bash functions used by build scripts +COLOR_NC='\033[0m' +COLOR_LIGHT_GRAY='\033[0;37m' +COLOR_GREEN='\033[0;32m' +COLOR_RED='\033[0;31m' die() { # Print a message and exit with code 1. 
diff --git a/tensorflow/tools/ci_build/builds/configured b/tensorflow/tools/ci_build/builds/configured index 25cb51ea7ccfb300d064f9a1a313bed57212832b..563e07e3afb2544d9dfde777860c9f3919a8d2ee 100755 --- a/tensorflow/tools/ci_build/builds/configured +++ b/tensorflow/tools/ci_build/builds/configured @@ -56,7 +56,7 @@ else fi pushd "${CI_TENSORFLOW_SUBMODULE_PATH:-.}" -yes "" | ./configure +$PYTHON_BIN_PATH configure.py popd # Gather and print build information diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh index 85c712d3c6db353574fda40363d58dc328259430..112dab3a7332bbe6446517843ecfe7ff9a526d0f 100755 --- a/tensorflow/tools/ci_build/builds/pip.sh +++ b/tensorflow/tools/ci_build/builds/pip.sh @@ -23,7 +23,7 @@ # # When executing the Python unit tests, the script obeys the shell # variables: TF_BUILD_BAZEL_CLEAN, TF_BUILD_INSTALL_EXTRA_PIP_PACKAGES, -# NO_TEST_ON_INSTALL +# NO_TEST_ON_INSTALL, PIP_TEST_ROOT # # TF_BUILD_BAZEL_CLEAN, if set to any non-empty and non-0 value, directs the # script to perform bazel clean prior to main build and test steps. @@ -41,6 +41,9 @@ # If NO_TEST_TFDBG_BINARIES has any non-empty and non-0 value, the testing of # TensorFlow Debugger (tfdbg) binaries and examples will be skipped. # +# If PIP_TEST_ROOT has a non-empty and a non-0 value, the whl files will be +# placed in that directory. +# # Any flags not listed in the usage above will be passed directly to Bazel. 
# # If the --test_tutorials flag is set, it will cause the script to run the @@ -70,6 +73,9 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/builds_common.sh" +SKIP_RETURN_CODE=112 + + # Get the command line arguments CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' ) shift @@ -162,7 +168,10 @@ echo "Python binary path to be used in PIP install: ${PYTHON_BIN_PATH} "\ "(Major.Minor version: ${PY_MAJOR_MINOR_VER})" # Build PIP Wheel file -PIP_TEST_ROOT="pip_test" +# Set default pip file folder unless specified by env variable +if [ -z "$PIP_TEST_ROOT" ]; then + PIP_TEST_ROOT="pip_test" +fi PIP_WHL_DIR="${PIP_TEST_ROOT}/whl" PIP_WHL_DIR=$(realpath ${PIP_WHL_DIR}) # Get absolute path rm -rf ${PIP_WHL_DIR} && mkdir -p ${PIP_WHL_DIR} @@ -236,106 +245,301 @@ if [[ $(uname) == "Linux" ]]; then fi fi -# Perform installation -echo "Installing pip whl file: ${WHL_PATH}" -# Create virtualenv directory for install test -VENV_DIR="${PIP_TEST_ROOT}/venv" +create_activate_virtualenv_and_install_tensorflow() { + # Create and activate a virtualenv; then install tensorflow pip package in it. + # + # Usage: + # create_activate_virtualenv_and_install_tensorflow [--clean] \ + # + # + # Arguments: + # --clean: Create a clean virtualenv, i.e., without --system-site-packages. + # VIRTUALENV_DIR: virtualenv directory to be created. + # TF_WHEEL_PATH: Path to the tensorflow wheel file to be installed in the + # virtualenv. 
+ + VIRTUALENV_FLAGS="--system-site-packages" + if [[ "$1" == "--clean" ]]; then + VIRTUALENV_FLAGS="" + shift + fi + + VIRTUALENV_DIR="$1" + TF_WHEEL_PATH="$2" + if [[ -d "${VIRTUALENV_DIR}" ]]; then + if rm -rf "${VIRTUALENV_DIR}" + then + echo "Removed existing virtualenv directory: ${VIRTUALENV_DIR}" + else + die "Failed to remove existing virtualenv directory: ${VIRTUALENV_DIR}" + fi + fi -if [[ -d "${VENV_DIR}" ]]; then - if rm -rf "${VENV_DIR}" + if mkdir -p "${VIRTUALENV_DIR}" then - echo "Removed existing virtualenv directory: ${VENV_DIR}" + echo "Created virtualenv directory: ${VIRTUALENV_DIR}" else - die "Failed to remove existing virtualenv directory: ${VENV_DIR}" + die "FAILED to create virtualenv directory: ${VIRTUALENV_DIR}" fi -fi -if mkdir -p ${VENV_DIR} -then - echo "Created virtualenv directory: ${VENV_DIR}" -else - die "FAILED to create virtualenv directory: ${VENV_DIR}" -fi - -# Verify that virtualenv exists -if [[ -z $(which virtualenv) ]]; then - die "FAILED: virtualenv not available on path" -fi + # Verify that virtualenv exists + if [[ -z $(which virtualenv) ]]; then + die "FAILED: virtualenv not available on path" + fi -virtualenv --system-site-packages -p "${PYTHON_BIN_PATH}" "${VENV_DIR}" || \ + virtualenv ${VIRTUALENV_FLAGS} \ + -p "${PYTHON_BIN_PATH}" "${VIRTUALENV_DIR}" || \ die "FAILED: Unable to create virtualenv" -source "${VENV_DIR}/bin/activate" || \ - die "FAILED: Unable to activate virtualenv" - + source "${VIRTUALENV_DIR}/bin/activate" || \ + die "FAILED: Unable to activate virtualenv in ${VIRTUALENV_DIR}" -# Install the pip file in virtual env (plus missing dependencies) + # Install the pip file in virtual env. -# Upgrade pip so it supports tags such as cp27mu, manylinux1 etc. -echo "Upgrade pip in virtualenv" -pip install --upgrade pip==8.1.2 + # Upgrade pip so it supports tags such as cp27mu, manylinux1 etc. + echo "Upgrade pip in virtualenv" + pip install --upgrade pip==8.1.2 -# Force tensorflow reinstallation. 
Otherwise it may not get installed from -# last build if it had the same version number as previous build. -PIP_FLAGS="--upgrade --force-reinstall" -pip install -v ${PIP_FLAGS} ${WHL_PATH} || \ + # Force tensorflow reinstallation. Otherwise it may not get installed from + # last build if it had the same version number as previous build. + PIP_FLAGS="--upgrade --force-reinstall" + pip install -v ${PIP_FLAGS} ${WHL_PATH} || \ die "pip install (forcing to reinstall tensorflow) FAILED" -echo "Successfully installed pip package ${WHL_PATH}" - -# Install extra pip packages required by the test-on-install -for PACKAGE in ${INSTALL_EXTRA_PIP_PACKAGES}; do - echo "Installing extra pip package required by test-on-install: ${PACKAGE}" - - pip install ${PACKAGE} || \ - die "pip install ${PACKAGE} FAILED" -done - -if [[ -n "${NO_TEST_ON_INSTALL}" ]] && - [[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then - echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:" - echo " Skipping ALL Python unit tests on install" -else - # Call run_pip_tests.sh to perform test-on-install - "${SCRIPT_DIR}/run_pip_tests.sh" --virtualenv ${GPU_FLAG} ${MAC_FLAG} || - die "PIP tests-on-install FAILED" -fi + echo "Successfully installed pip package ${TF_WHEEL_PATH}" +} -# Test user ops -if [[ "${DO_TEST_USER_OPS}" == "1" ]]; then - "${SCRIPT_DIR}/test_user_ops.sh" --virtualenv ${GPU_FLAG} || \ - die "PIP user-op tests-on-install FAILED" -fi +################################################################################ +# Smoke test of tensorflow install in clean virtualenv +################################################################################ +do_clean_virtualenv_smoke_test() { + if [[ -n "${NO_TEST_ON_INSTALL}" ]] && + [[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then + echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:" + echo " Skipping smoke test of tensorflow install in clean virtualenv" + return ${SKIP_RETURN_CODE} + fi -# Test TensorFlow Debugger (tfdbg) examples. 
-if [[ "${DO_TEST_TFDBG_BINARIES}" == "1" ]]; then - echo - echo "Testing TensorFlow Debugger (tfdbg) binaries" - echo + CLEAN_VENV_DIR="${PIP_TEST_ROOT}/venv_clean" + create_activate_virtualenv_and_install_tensorflow --clean \ + "${CLEAN_VENV_DIR}" "${WHL_PATH}" # cd to a temporary directory to avoid picking up Python files in the source # tree. TMP_DIR=$(mktemp -d) pushd "${TMP_DIR}" + if [[ $(python -c "import tensorflow as tf; print(tf.Session().run(tf.constant(42)))") == 42 ]]; + then + echo "Smoke test of tensorflow install in clean virtualenv PASSED." + else + echo "Smoke test of tensroflow install in clean virtualenv FAILED." + return 1 + fi - "${SCRIPT_DIR}/../../../python/debug/examples/examples_test.sh" \ - --virtualenv || \ - die "PIP tests-on-install of tfdbg binaries FAILED" + deactivate + if [[ $? != 0 ]]; then + echo "FAILED: Unable to deactivate virtualenv from ${CLEAN_VENV_DIR}" + return 1 + fi popd -fi + rm -rf "${TMP_DIR}" "${CLEAN_VENV_DIR}" +} -# Optional: Run the tutorial tests -if [[ "${DO_TEST_TUTORIALS}" == "1" ]]; then - "${SCRIPT_DIR}/test_tutorials.sh" --virtualenv || \ - die "PIP tutorial tests-on-install FAILED" -fi +################################################################################ +# Perform installation of tensorflow in "non-clean" virtualenv and tests against +# the install. +################################################################################ +do_virtualenv_pip_test() { + # Create virtualenv directory for install test + VENV_DIR="${PIP_TEST_ROOT}/venv" + create_activate_virtualenv_and_install_tensorflow \ + "${CLEAN_VENV_DIR}" "${WHL_PATH}" + + # Install extra pip packages required by the test-on-install + for PACKAGE in ${INSTALL_EXTRA_PIP_PACKAGES}; do + echo "Installing extra pip package required by test-on-install: ${PACKAGE}" + + pip install ${PACKAGE} + if [[ $? 
!= 0 ]]; then + echo "pip install ${PACKAGE} FAILED" + return 1 + fi + done -# Optional: Run integration tests -if [[ "${DO_INTEGRATION_TESTS}" == "1" ]]; then - "${SCRIPT_DIR}/integration_tests.sh" --virtualenv || \ - die "Integration tests on install FAILED" -fi + if [[ -n "${NO_TEST_ON_INSTALL}" ]] && + [[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then + echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:" + echo " Skipping ALL Python unit tests on install" + return ${SKIP_RETURN_CODE} + else + # Call run_pip_tests.sh to perform test-on-install + "${SCRIPT_DIR}/run_pip_tests.sh" --virtualenv ${GPU_FLAG} ${MAC_FLAG} + if [[ $? != 0 ]]; then + echo "PIP tests-on-install FAILED" + return 1 + fi + fi +} -deactivate || \ - die "FAILED: Unable to deactivate virtualenv" +################################################################################ +# Run tests tagged with oss_serial against the virtualenv install. +################################################################################ +do_virtualenv_oss_serial_pip_test() { + if [[ -n "${NO_TEST_ON_INSTALL}" ]] && + [[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then + echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:" + echo " Skipping Python unit tests on install tagged with oss_serial" + return ${SKIP_RETURN_CODE} + else + # Call run_pip_tests.sh to perform test-on-install + "${SCRIPT_DIR}/run_pip_tests.sh" \ + --virtualenv ${GPU_FLAG} ${MAC_FLAG} --oss_serial + if [[ $? != 0 ]]; then + echo "PIP tests-on-install (oss_serial) FAILED" + return 1 + fi + fi +} + +################################################################################ +# Test user ops (optional). +################################################################################ +do_test_user_ops() { + if [[ "${DO_TEST_USER_OPS}" == "1" ]]; then + "${SCRIPT_DIR}/test_user_ops.sh" --virtualenv ${GPU_FLAG} + if [[ $? 
!= 0 ]]; then + echo "PIP user-op tests-on-install FAILED" + return 1 + fi + else + echo "Skipping user-op test-on-install due to DO_TEST_USER_OPS = ${DO_TEST_USER_OPS}" + return ${SKIP_RETURN_CODE} + fi +} + +################################################################################ +# Test TensorFlow Debugger (tfdbg) binaries (optional). +################################################################################ +do_test_tfdbg_binaries() { + if [[ "${DO_TEST_TFDBG_BINARIES}" == "1" ]]; then + # cd to a temporary directory to avoid picking up Python files in the source + # tree. + TMP_DIR=$(mktemp -d) + pushd "${TMP_DIR}" + + "${SCRIPT_DIR}/../../../python/debug/examples/examples_test.sh" \ + --virtualenv + if [[ $? != 0 ]]; then + echo "PIP tests-on-install of tfdbg binaries FAILED" + return 1 + fi + popd + else + echo "Skipping test of tfdbg binaries due to DO_TEST_TFDBG_BINARIES = ${DO_TEST_TFDBG_BINARIES}" + return ${SKIP_RETURN_CODE} + fi +} + +################################################################################ +# Test tutorials (optional). +################################################################################ +do_test_tutorials() { + if [[ "${DO_TEST_TUTORIALS}" == "1" ]]; then + "${SCRIPT_DIR}/test_tutorials.sh" --virtualenv + if [[ $? != 0 ]]; then + echo "PIP tutorial tests-on-install FAILED" + return 1 + fi + else + echo "Skipping tutorial tests-on-install due to DO_TEST_TUTORIALS = ${DO_TEST_TUTORIALS}" + return ${SKIP_RETURN_CODE} + fi +} + +################################################################################ +# Integration test for ffmpeg (optional). +################################################################################ +do_ffmpeg_integration_test() { + # Optional: Run integration tests + if [[ "${DO_INTEGRATION_TESTS}" == "1" ]]; then + "${SCRIPT_DIR}/integration_tests.sh" --virtualenv + if [[ $? 
!= 0 ]]; then + echo "Integration tests on install FAILED" + return 1 + fi + else + echo "Skipping ffmpeg integration due to DO_INTEGRATION_TESTS = ${DO_INTEGRATION_TESTS}" + return ${SKIP_RETURN_CODE} + fi +} + + +# List of all PIP test tasks and their descriptions. +PIP_TASKS=("do_clean_virtualenv_smoke_test" "do_virtualenv_pip_test" "do_virtualenv_oss_serial_pip_test" "do_test_user_ops" "do_test_tfdbg_binaries" "do_test_tutorials" "do_ffmpeg_integration_test") +PIP_TASKS_DESC=("Smoke test of pip install in clean virtualenv" "PIP tests in virtualenv" "PIP test in virtualenv (tag: oss_serial)" "User ops test" "TensorFlow Debugger (tfdbg) binaries test" "Tutorials test" "ffmpeg integration test") + + +# Execute all the PIP test steps. +COUNTER=0 +FAIL_COUNTER=0 +PASS_COUNTER=0 +SKIP_COUNTER=0 +while [[ ${COUNTER} -lt "${#PIP_TASKS[@]}" ]]; do + INDEX=COUNTER + ((INDEX++)) + + echo "" + echo "=== PIP test step ${INDEX} of ${#PIP_TASKS[@]}: "\ +"${PIP_TASKS[COUNTER]} (${PIP_TASKS_DESC[COUNTER]}) ===" + echo "" + + ${PIP_TASKS[COUNTER]} + RESULT=$? + + if [[ ${RESULT} == ${SKIP_RETURN_CODE} ]]; then + ((SKIP_COUNTER++)) + elif [[ ${RESULT} != "0" ]]; then + ((FAIL_COUNTER++)) + else + ((PASS_COUNTER++)) + fi + + STEP_EXIT_CODES+=(${RESULT}) + + echo "" + ((COUNTER++)) +done + +deactivate || die "FAILED: Unable to deactivate virtualenv from ${VENV_DIR}" + + +# Print summary of build results +COUNTER=0 +echo "==== Summary of PIP test results ====" +while [[ ${COUNTER} -lt "${#PIP_TASKS[@]}" ]]; do + INDEX=COUNTER + ((INDEX++)) + + echo "${INDEX}. 
${PIP_TASKS[COUNTER]}: ${PIP_TASKS_DESC[COUNTER]}" + if [[ ${STEP_EXIT_CODES[COUNTER]} == ${SKIP_RETURN_CODE} ]]; then + printf " ${COLOR_LIGHT_GRAY}SKIP${COLOR_NC}\n" + elif [[ ${STEP_EXIT_CODES[COUNTER]} == "0" ]]; then + printf " ${COLOR_GREEN}PASS${COLOR_NC}\n" + else + printf " ${COLOR_RED}FAIL${COLOR_NC}\n" + fi + + ((COUNTER++)) +done + +echo +echo "${SKIP_COUNTER} skipped; ${FAIL_COUNTER} failed; ${PASS_COUNTER} passed." + +echo +if [[ ${FAIL_COUNTER} == "0" ]]; then + printf "PIP test ${COLOR_GREEN}PASSED${COLOR_NC}\n" +else + printf "PIP test ${COLOR_RED}FAILED${COLOR_NC}\n" + exit 1 +fi diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh index 8e364f7ffb7c857762f30d85b2edcdb34c16c45e..9a6890401b7ab3dd54c50ddf41c539e1c6de4032 100755 --- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh +++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh @@ -18,7 +18,7 @@ # Run the python unit tests from the source code on the pip installation. # # Usage: -# run_pip_tests.sh [--virtualenv] [--gpu] [--mac] +# run_pip_tests.sh [--virtualenv] [--gpu] [--mac] [--oss_serial] # # If the flag --virtualenv is set, the script will use "python" as the Python # binary path. Otherwise, it will use tools/python_bin_path.sh to determine @@ -30,6 +30,10 @@ # The --mac flag informs the script that this is running on mac. Mac does not # have flock, so we should skip using parallel_gpu_execute on mac. 
# +# The --oss_serial flag lets the script run only the py tests with the +# oss_serial tag, in a serial fashion, i.e., using the bazel flag +# --local_test_jobs=1 +# # TF_BUILD_APPEND_ARGUMENTS: # Additional command line arguments for the bazel, # pip.sh or android.sh command @@ -42,6 +46,7 @@ source "${SCRIPT_DIR}/builds_common.sh" IS_VIRTUALENV=0 IS_GPU=0 IS_MAC=0 +IS_OSS_SERIAL=0 while true; do if [[ "$1" == "--virtualenv" ]]; then IS_VIRTUALENV=1 @@ -49,6 +54,8 @@ while true; do IS_GPU=1 elif [[ "$1" == "--mac" ]]; then IS_MAC=1 + elif [[ "$1" == "--oss_serial" ]]; then + IS_OSS_SERIAL=1 fi shift @@ -69,10 +76,19 @@ ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow # Do not run tests with "no_pip" tag. If running GPU tests, also do not run # tests with no_pip_gpu tag. -PIP_TEST_FILTER_TAG="-no_pip" +PIP_TEST_FILTER_TAG="-no_pip,-no_oss" +if [[ ${IS_OSS_SERIAL} == "1" ]]; then + PIP_TEST_FILTER_TAG="${PIP_TEST_FILTER_TAG},oss_serial" +else + PIP_TEST_FILTER_TAG="${PIP_TEST_FILTER_TAG},-oss_serial" +fi + if [[ ${IS_GPU} == "1" ]]; then PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}" fi +if [[ ${IS_MAC} == "1" ]]; then + PIP_TEST_FILTER_TAG="-nomac,${PIP_TEST_FILTER_TAG}" +fi # Bazel flags we need for all tests: # define=no_tensorflow_py_deps=true, to skip all test dependencies. @@ -104,7 +120,7 @@ else fi export TF_NEED_CUDA=$IS_GPU -yes "" | ./configure +${PYTHON_BIN_PATH} configure.py # Figure out how many concurrent tests we can run and do run the tests. BAZEL_PARALLEL_TEST_FLAGS="" @@ -126,6 +142,10 @@ else fi fi +if [[ ${IS_OSS_SERIAL} == 1 ]]; then + BAZEL_PARALLEL_TEST_FLAGS="--local_test_jobs=1" +fi + # Actually run the tests. 
bazel test ${BAZEL_FLAGS} ${BAZEL_PARALLEL_TEST_FLAGS} -- \ ${BAZEL_TEST_TARGETS} diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 1cf87d7c7c09613d2a7f265e5cc1b54a3e2ae47e..13cfaad57e76028eaa7484aea334e8ed260b83b6 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -95,6 +95,10 @@ # # This script can be used by Jenkins parameterized / matrix builds. +# TODO(jhseu): Temporary for the gRPC pull request due to the +# protobuf -> protobuf_archive rename. Remove later. +TF_BUILD_BAZEL_CLEAN=1 + # Helper function: Convert to lower case to_lower () { echo "$1" | tr '[:upper:]' '[:lower:]' @@ -358,7 +362,7 @@ if [[ "${TF_BUILD_APPEND_ARGUMENTS}" == *"--test_tag_filters="* ]]; then fi done else - EXTRA_ARGS="${TF_BUILD_APPEND_ARGUMENTS} --test_tag_filters=-benchmark-test" + EXTRA_ARGS="${TF_BUILD_APPEND_ARGUMENTS} --test_tag_filters=-no_oss,-oss_serial,-benchmark-test" if [[ ${IS_MAC} == "1" ]]; then EXTRA_ARGS="${EXTRA_ARGS},-nomac" fi diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index ddc95f690021feefe6725fada3385677aac09a98..68e826ccd5576100380c727a710482db1d5433f7 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -24,6 +24,7 @@ # Current script directory SCRIPT_DIR=$( cd ${0%/*} && pwd -P ) +source "${SCRIPT_DIR}/builds/builds_common.sh" # Helper functions die() { @@ -418,9 +419,25 @@ do_pip_smoke_test() { "The pip smoke test failed." } +do_code_link_check() { + tensorflow/tools/ci_build/code_link_check.sh +} + +do_check_load_py_test() { + BUILD_CMD="bazel build //tensorflow/tools/pip_package:check_load_py_test" + ${BUILD_CMD} + cmd_status \ + "check_load_py_test failed to build." + + BUILD_CMD="bazel-bin/tensorflow/tools/pip_package/check_load_py_test" + ${BUILD_CMD} + cmd_status \ + "check_load_py_test failed." 
+} + # Supply all sanity step commands and descriptions -SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test") -SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package") +SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check") +SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links") INCREMENTAL_FLAG="" @@ -463,7 +480,7 @@ while [[ ${COUNTER} -lt "${#SANITY_STEPS[@]}" ]]; do ((PASS_COUNTER++)) fi - IFS=" " read -r -a STEP_EXIT_CODES <<< "${RESULT}" + STEP_EXIT_CODES+=(${RESULT}) echo "" ((COUNTER++)) @@ -478,20 +495,21 @@ while [[ ${COUNTER} -lt "${#SANITY_STEPS[@]}" ]]; do echo "${INDEX}. ${SANITY_STEPS[COUNTER]}: ${SANITY_STEPS_DESC[COUNTER]}" if [[ ${STEP_EXIT_CODES[COUNTER]} == "0" ]]; then - echo " PASS" + printf " ${COLOR_GREEN}PASS${COLOR_NC}\n" else - echo " FAIL" + printf " ${COLOR_RED}FAIL${COLOR_NC}\n" fi ((COUNTER++)) done -echo "" +echo echo "${FAIL_COUNTER} failed; ${PASS_COUNTER} passed." 
-echo "" +echo if [[ ${FAIL_COUNTER} == "0" ]]; then - echo "Sanity checks PASSED" + printf "Sanity checks ${COLOR_GREEN}PASSED${COLOR_NC}\n" else - die "Sanity checks FAILED" + printf "Sanity checks ${COLOR_RED}FAILED${COLOR_NC}\n" + exit 1 fi diff --git a/tensorflow/tools/ci_build/code_link_check.sh b/tensorflow/tools/ci_build/code_link_check.sh new file mode 100755 index 0000000000000000000000000000000000000000..09130482cc9969a1c9e63fe73e183b631f53e0de --- /dev/null +++ b/tensorflow/tools/ci_build/code_link_check.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# please run this at root directory of tensorflow +success=1 + +for i in `grep -onI https://www.tensorflow.org/code/\[a-zA-Z0-9/._-\]\* -r tensorflow` +do + filename=`echo $i|awk -F: '{print $1}'` + linenumber=`echo $i|awk -F: '{print $2}'` + target=`echo $i|awk -F: '{print $4}'|tail -c +27` + + # skip files in tensorflow/models + if [[ $target == tensorflow_models/* ]] ; then + continue + fi + + if [ ! -f $target ] && [ ! -d $target ]; then + success=0 + echo Broken link $target at line $linenumber of file $filename + fi +done + +if [ $success == 0 ]; then + echo Code link check fails. + exit 1 +fi + +echo Code link check success. 
diff --git a/tensorflow/tools/ci_build/install/install_buildifier.sh b/tensorflow/tools/ci_build/install/install_buildifier.sh index b2dfcf8db7605a08ed9554784b8de5cecac86af7..967c62bac03a124dca885b2118e0777204afa24d 100755 --- a/tensorflow/tools/ci_build/install/install_buildifier.sh +++ b/tensorflow/tools/ci_build/install/install_buildifier.sh @@ -15,14 +15,12 @@ # ============================================================================== set -e -BUILDIFIER_DIR="buildifier" -mkdir ${BUILDIFIER_DIR} -curl -Ls https://github.com/bazelbuild/buildifier/archive/0.4.5.tar.gz | \ - tar -C "${BUILDIFIER_DIR}" --strip-components=1 -xz -pushd ${BUILDIFIER_DIR} +# Download buildifier. +wget https://github.com/bazelbuild/buildtools/releases/download/0.4.5/buildifier +chmod +x buildifier +sudo mv buildifier /usr/local/bin/. -bazel build buildifier:buildifier --spawn_strategy=standalone --genrule_strategy=standalone -sudo cp bazel-bin/buildifier/buildifier /usr/local/bin/ - -popd -rm -rf ${BUILDIFIER_DIR} +# Download buildozer. +wget https://github.com/bazelbuild/buildtools/releases/download/0.4.5/buildozer +chmod +x buildozer +sudo mv buildozer /usr/local/bin/. 
diff --git a/tensorflow/tools/ci_build/install/install_golang.sh b/tensorflow/tools/ci_build/install/install_golang.sh index fef203b869704155e9c3b226bbef4af63e2e706c..88bc2960e347c9a0fb26b04863d359598edcce10 100755 --- a/tensorflow/tools/ci_build/install/install_golang.sh +++ b/tensorflow/tools/ci_build/install/install_golang.sh @@ -16,7 +16,7 @@ set -ex -GOLANG_URL="https://storage.googleapis.com/golang/go1.7.5.linux-amd64.tar.gz" +GOLANG_URL="https://storage.googleapis.com/golang/go1.8.3.linux-amd64.tar.gz" sudo mkdir -p /usr/local wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 8768852dc7e847e29a6089f963d33bfb137675d7..44fc21df9458c0880d1972603c93e1590e2b0643 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -44,8 +44,8 @@ pip2 install --upgrade markdown==2.6.8 pip3 install --upgrade markdown==2.6.8 # Install protobuf. -pip2 install --upgrade protobuf==3.2.0 -pip3 install --upgrade protobuf==3.2.0 +pip2 install --upgrade protobuf==3.3.0 +pip3 install --upgrade protobuf==3.3.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -86,5 +86,6 @@ pip2 install mock pip2 install portpicker pip3 install portpicker -pip2 install backports.weakref==1.0rc1 -pip3 install backports.weakref==1.0rc1 +# TensorFlow Serving integration tests require the following: +pip2 install grpcio +pip3 install grpcio diff --git a/tensorflow/tools/ci_build/install/install_proto3.sh b/tensorflow/tools/ci_build/install/install_proto3.sh index 773c89b70bbe64f7923645ea5c3c532e52d02c2d..7934002b2c982cd10216016f8614b70b77b58e29 100755 --- a/tensorflow/tools/ci_build/install/install_proto3.sh +++ b/tensorflow/tools/ci_build/install/install_proto3.sh @@ -17,9 +17,9 @@ # Install protobuf3. 
# Select protobuf version. -PROTOBUF_VERSION="3.2.0" +PROTOBUF_VERSION="3.3.0" protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_protobuf_ver=$(protoc --version | awk '{print $2}') +local_protobuf_ver=$(protoc --version) local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g') if [[ -z $local_protobuf_ver_flat ]]; then local_protobuf_ver_flat=0 @@ -30,7 +30,7 @@ if (( $local_protobuf_ver_flat < $protobuf_ver_flat )); then PROTOBUF_ZIP=$(basename "${PROTOBUF_URL}") UNZIP_DEST="google-protobuf" - wget -q "${PROTOBUF_URL}" + wget "${PROTOBUF_URL}" unzip "${PROTOBUF_ZIP}" -d "${UNZIP_DEST}" cp "${UNZIP_DEST}/bin/protoc" /usr/local/bin/ diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index edfc4e3a98f7c613fd7db80f8426514ad09f4f72..706d414746408d0bc918fb7408985561dac70d7c 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -64,7 +64,7 @@ set -e pip3.5 install --upgrade six==1.10.0 # Install protobuf. -pip3.5 install --upgrade protobuf==3.2.0 +pip3.5 install --upgrade protobuf==3.3.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -90,5 +90,4 @@ pip3.5 install portpicker pip3.5 install werkzeug -pip3.5 install backports.weakref==1.0rc1 - +pip3.5 install grpcio diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh index 467e4ab7e53ebd1c6985bcc908c9efdda10cef17..ca840796543a055d58359449d43944720635f0c4 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh @@ -30,10 +30,10 @@ export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 # Only running cc tests, python version does not matter. 
export PYTHON_BIN_PATH=`which python` -yes "" | ./configure +$PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=cc -k \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh index e2bbc0e8c0be0d1069eb85364ba8a137b950cb3a..5c82c9efafa14e8491d50a02f35c0498a8f9ef79 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh @@ -29,10 +29,10 @@ export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=`which python2` -yes "" | ./configure +$PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \ +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index a03cab0cca5c375e668a2adeae64c48ac2b217a0..7155636a53fa9333945fdec0f0582c745db8ba17 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -29,10 +29,10 @@ export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=`which python3` -yes "" | ./configure +$PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. 
-bazel test --test_tag_filters=-gpu,-benchmark-test -k \ +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --test_output=errors -- \ //tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh index 32de5cea200d4a43e5885364a9aeeafd2fa51af6..218d2a899135401d6fcc79677dd5ab3703034919 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh @@ -29,10 +29,10 @@ export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=`which python3` -yes "" | ./configure +$PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \ +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh index 6acc26213835c0d2924f9ee0a31a80790bf5d75e..dff72c25bf7a6d9b0593391fbc090fd8a8ab537f 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh @@ -32,10 +32,10 @@ export PYTHON_BIN_PATH=`which python3` export TF_NEED_CUDA=1 export TF_CUDA_COMPUTE_CAPABILITIES=3.7 -yes "" | ./configure +$PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. 
-bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ +bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --build_tests_only --test_output=errors --local_test_jobs=8 \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh index e73fe046c967b0bb3db6eb5b109516c0d207a1e4..a36a8445afdebea15cf1fcf3c73d15ef4200a090 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh @@ -32,10 +32,10 @@ export PYTHON_BIN_PATH=`which python3` export TF_NEED_CUDA=1 export TF_CUDA_COMPUTE_CAPABILITIES=3.7 -yes "" | ./configure +$PYTHON_BIN_PATH configure.py # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ +bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --build_tests_only --test_output=errors --local_test_jobs=8 \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh index e5f4a22f7ade7eb5c260a7a486cd5d3fa75d5859..0ee894e2c44e8115148612191c949f6f4b0d42ba 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh @@ -30,11 +30,10 @@ export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=$(which python2) -yes "" | ./configure +$PYTHON_BIN_PATH configure.py which bazel -bazel test --test_tag_filters=-gpu,-benchmark-test,-nomac \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ 
--test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ - //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... \ - -//tensorflow/tensorboard/... + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh b/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh index 59ba71f5df77fd967e3699bce628adc49c7893ee..3e31aa1ce106531a32d0d8860de87a9aa490ae0c 100755 --- a/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh +++ b/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -PROTOBUF_VERSION="3.2.0" +PROTOBUF_VERSION="3.3.1" PYTHON_BIN=${PYTHON_BIN:-python} DIR=${PWD}/protobuf diff --git a/tensorflow/tools/ci_build/update_version.sh b/tensorflow/tools/ci_build/update_version.sh index 682f5329f58fffa5f2030c7e33db14bd3e343165..b707ee338a2786ce3946d9e3d34da311b9f512f5 100755 --- a/tensorflow/tools/ci_build/update_version.sh +++ b/tensorflow/tools/ci_build/update_version.sh @@ -130,12 +130,6 @@ if [[ ${OLD_MAJOR} != ${MAJOR} ]] || [[ ${OLD_MINOR} != ${MINOR} ]]; then echo "Detected Major.Minor change. 
"\ "Updating pattern ${OLD_R_MAJOR_MINOR} to ${R_MAJOR_MINOR} in additional files" - # Update tensorflow/tensorboard/README.md - TENSORBOARD_README_MD="${TF_SRC_DIR}/tensorboard/README.md" - check_existence file "${TENSORBOARD_README_MD}" - sed -i -r -e "s/${OLD_R_MAJOR_MINOR}/${R_MAJOR_MINOR}/g" \ - "${TENSORBOARD_README_MD}" - # Update dockerfiles DEVEL_DOCKERFILE="${TF_SRC_DIR}/tools/docker/Dockerfile.devel" check_existence file "${DEVEL_DOCKERFILE}" diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index 8853dc53b17b5b5f1dda096817c67723fdbefcc4..05392c27248f6603f61d59358887867cd9816550 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -56,3 +56,7 @@ export PATH="/c/tools/cuda/bin:$PATH" # Set the common build options on Windows export BUILD_OPTS='--copt=-w --host_copt=-w --verbose_failures --experimental_ui' + +# Build TF with wrapper-less CROSSTOOL +# TODO(pcloudy): Remove this after wrapper-less CROSSTOOL becomes default +export NO_MSVC_WRAPPER=1 diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh index cc157c33f501c8da1e17656c05232458a8c6aaac..7cb81c20f02edc4593591cba1be2bfd2074751e4 100644 --- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh @@ -60,7 +60,7 @@ reinstall_tensorflow_pip ${PIP_NAME} # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow -# GPU tests are very flaky when running concurently, so set local_test_jobs=1 +# GPU tests are very flaky when running concurrently, so set local_test_jobs=1 bazel test -c opt --config=win-cuda $BUILD_OPTS -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ 
--test_tag_filters=-no_pip,-no_windows,-no_windows_gpu \ diff --git a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh index 1106413071393be8cd60c88887bffa7ef673dc08..4a2f954dc957a4ba357437247646bf0c323f4e0c 100755 --- a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh +++ b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh @@ -33,8 +33,9 @@ export TF_NEED_CUDA=1 export TF_ENABLE_XLA=1 export TF_CUDA_COMPUTE_CAPABILITIES=3.7 -yes "" | ./configure +$PYTHON_BIN_PATH configure.py +bazel clean # Run bazel test command. Double test timeouts to avoid flakes. bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index f92edd0dd8863fa7a3a6ad764a895370d48a5958..8a8667957ae4acb97356d4a141edd422509b48c7 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -15,6 +15,7 @@ py_library( name = "public_api", srcs = ["public_api.py"], srcs_version = "PY2AND3", + deps = ["//tensorflow/python:util"], ) py_test( @@ -32,6 +33,7 @@ py_library( name = "traverse", srcs = ["traverse.py"], srcs_version = "PY2AND3", + deps = ["//tensorflow/python:util"], ) py_test( diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index fb40cf0833f08fc142aec18fe8940ce836453906..19959ea6d260d5aded5a3f37850025f6722d82ee 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -24,7 +24,9 @@ py_test( srcs_version = "PY2AND3", deps = [ "tf_upgrade", - "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", ], ) diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py index f7dbfea7fb0f3463cd708cde8762eb28b69b05a1..e40ecb43f9a00bee7309895969ff65e48b95b4e9 100644 --- 
a/tensorflow/tools/dist_test/python/mnist_replica.py +++ b/tensorflow/tools/dist_test/python/mnist_replica.py @@ -17,7 +17,7 @@ A simple softmax model with one hidden layer is defined. The parameters (weights and biases) are located on one parameter server (ps), while the ops -are executed on two worker nodes by default. The TF sessions also run on the +are executed on two worker nodes by default. The TF sessions also run on the worker node. Multiple invocations of this script can be done in parallel, with different values for --task_index. There should be exactly one invocation with @@ -123,9 +123,7 @@ def main(unused_argv): is_chief = (FLAGS.task_index == 0) if FLAGS.num_gpus > 0: - if FLAGS.num_gpus < num_workers: - raise ValueError("number of gpus is less than number of workers") - # Avoid gpu allocation conflict: now allocate task_num -> #gpu + # Avoid gpu allocation conflict: now allocate task_num -> #gpu # for each worker in the corresponding machine gpu = (FLAGS.task_index % FLAGS.num_gpus) worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) diff --git a/tensorflow/tools/dist_test/scripts/BUILD b/tensorflow/tools/dist_test/scripts/BUILD index c329f0bbe8779fe300e601a1f41d6c123688815a..ce2fa5c743ece40eae10b30f4b2626a9cfada147 100644 --- a/tensorflow/tools/dist_test/scripts/BUILD +++ b/tensorflow/tools/dist_test/scripts/BUILD @@ -17,6 +17,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":k8s_tensorflow_lib", - "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile index 5b3f1f936a48bb448b712152c57c095226efea8e..07a972400df46f59c2d24b7b8e99bd690659b83a 100644 --- a/tensorflow/tools/docker/Dockerfile +++ b/tensorflow/tools/docker/Dockerfile @@ -24,14 +24,15 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ rm get-pip.py RUN pip --no-cache-dir install \ + Pillow \ + h5py \ ipykernel \ jupyter \ matplotlib \ numpy \ + pandas \ 
scipy \ sklearn \ - pandas \ - Pillow \ && \ python -m ipykernel.kernelspec diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 38a67f80aae5c6ad66639c24059ac50a3c6f3220..1b97c0d10830f92118bf6b597558c107a0182a92 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -72,7 +72,7 @@ RUN mkdir /bazel && \ RUN git clone https://github.com/tensorflow/tensorflow.git && \ cd tensorflow && \ - git checkout r1.2 + git checkout r1.3 WORKDIR /tensorflow # TODO(craigcitro): Don't install the pip package, since it makes it diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index d0a038a9db61c97643678d9fbca8974df0f84c8f..80b45ae70473ccfbf8869d846a080d15dfcfd905 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -19,6 +19,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ zlib1g-dev \ openjdk-8-jdk \ openjdk-8-jre-headless \ + wget \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -72,7 +73,7 @@ RUN mkdir /bazel && \ RUN git clone https://github.com/tensorflow/tensorflow.git && \ cd tensorflow && \ - git checkout r1.2 + git checkout r1.3 WORKDIR /tensorflow # Configure the build for our CUDA configuration. 
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index 3ba1e963f92a0fd7294a36288785545962f40146..da83a300580b660bd2cea890eff8acc8a96103b2 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -24,14 +24,15 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ rm get-pip.py RUN pip --no-cache-dir install \ + Pillow \ + h5py \ ipykernel \ jupyter \ matplotlib \ numpy \ + pandas \ scipy \ sklearn \ - pandas \ - Pillow \ && \ python -m ipykernel.kernelspec diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index 3e45ae362c71021ec1931c59acd1c38fbfac8fc6..3780bde2beeac389437627b012d95be7aa9dbbd2 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -54,6 +54,30 @@ for additional containers, such as release candidates or nightly builds. ## Rebuilding the containers -Just pick the dockerfile corresponding to the container you want to build, and run +Building TensorFlow Docker containers should be done through the +[parameterized_docker_build.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/parameterized_docker_build.sh) +script. The raw Dockerfiles should not be used directly as they contain strings +to be replaced by the script during the build. - $ docker build --pull -t $USER/tensorflow-suffix -f Dockerfile.suffix . +To use the script, specify the container type (`CPU` vs. `GPU`), the desired +Python version (`PYTHON2` vs. `PYTHON3`) and whether the developer Docker image +is to be built (`NO` vs. `YES`). In addition, you need to specify the central +location from where the pip package of TensorFlow will be downloaded. 
+ +For example, to build a CPU-only non-developer Docker image for Python 2, using +TensorFlow's nightly pip package: + +``` bash +export TF_DOCKER_BUILD_IS_DEVEL=NO +export TF_DOCKER_BUILD_TYPE=CPU +export TF_DOCKER_BUILD_PYTHON_VERSION=PYTHON2 + +export NIGHTLY_VERSION="1.head" +export TF_DOCKER_BUILD_CENTRAL_PIP=$(echo ${TF_DOCKER_BUILD_PYTHON_VERSION} | sed s^PYTHON2^http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=${TF_DOCKER_BUILD_PYTHON_VERSION},label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp27-cp27mu-manylinux1_x86_64.whl^ | sed s^PYTHON3^http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp35-cp35m-manylinux1_x86_64.whl^) + +tensorflow/tools/docker/parameterized_docker_build.sh +``` + +If successful, the image will be tagged as `${USER}/tensorflow:latest` by default. + +Rebuilding GPU images requires [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). 
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 8e27b133c2fa33a8f6366b0f94a596cf1ca7c1a2..8f10bc9e0ca3c947b8ca75663444309088e0513e 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -37,6 +37,7 @@ py_library( srcs = ["parser.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + deps = ["@com_github_andreif_codegen"], ) py_test( @@ -44,7 +45,6 @@ py_test( size = "small", srcs = ["parser_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":parser", "//tensorflow/python:platform_test", @@ -78,13 +78,10 @@ py_test( size = "small", srcs = ["generate_lib_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":generate_lib", ":parser", - "//tensorflow:tensorflow_py", "//tensorflow/python:platform_test", - "//tensorflow/python/debug:debug_py", ], ) @@ -105,7 +102,12 @@ py_test( srcs = ["build_docs_test.py"], data = ["//tensorflow:docs_src"], srcs_version = "PY2AND3", - tags = ["manual"], + tags = [ + # No reason to run sanitizers for this test. 
+ "noasan", + "nomsan", + "notsan", + ], deps = [ ":generate_lib", "//tensorflow:tensorflow_py", diff --git a/tensorflow/tools/docs/build_docs_test.py b/tensorflow/tools/docs/build_docs_test.py index d28dd93b9a8d5eb19af414622c1d1b22516f9c1c..ae293f6576456ecdbb8a4b1ee4e8e4f40482ad94 100644 --- a/tensorflow/tools/docs/build_docs_test.py +++ b/tensorflow/tools/docs/build_docs_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import os +import sys +import textwrap import tensorflow as tf from tensorflow.python import debug as tf_debug @@ -29,19 +31,40 @@ from tensorflow.tools.docs import generate_lib class Flags(object): resource_root = resource_loader.get_root_dir_with_all_resources() - src_dir = os.path.join(resource_root, 'third_party/tensorflow/docs_src') - base_dir = os.path.join(resource_root, 'third_party/tensorflow/') + src_dir = os.path.join(resource_root, 'tensorflow/docs_src') + base_dir = os.path.join(resource_root, 'tensorflow/') output_dir = googletest.GetTempDir() class BuildDocsTest(googletest.TestCase): def testBuildDocs(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + doc_generator = generate_lib.DocGenerator() doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)]) - status = doc_generator.build(Flags()) + try: + status = doc_generator.build(Flags()) + except RuntimeError as e: + if not e.args[0].startswith('Modules nested too deep'): + raise + + msg = textwrap.dedent("""\ + %s + + **************************************************************** + If this test fails here, you have most likely introduced an + unsealed module. Make sure to use `remove_undocumented` or similar + utilities to avoid leaking symbols. See above for more information + on the exact point of failure. + **************************************************************** + """ % e.args[0]) + + raise RuntimeError(msg) if status: self.fail('Found %s Errors!' 
% status) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 99872e1d8446ab84bcf77caeb86003d86db85e52..bbeb3921d7b75a9d06d99e0131e1886af3849f2a 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import os +import sys import six @@ -90,6 +91,7 @@ def write_docs(output_dir, parser_config, yaml_toc): # Parse and write Markdown pages, resolving cross-links (@{symbol}). for full_name, py_object in six.iteritems(parser_config.index): + parser_config.reference_resolver.current_doc_full_name = full_name if full_name in parser_config.duplicate_of: continue @@ -181,7 +183,7 @@ def add_dict_to_dict(add_from, add_to): # Exclude some libaries in contrib from the documentation altogether. def _get_default_private_map(): - return {} + return {'tf.test': ['mock']} # Exclude members of some libaries. @@ -390,6 +392,9 @@ def _other_docs(src_dir, output_dir, reference_resolver): print('Skipping excluded file %s...' 
% base_name) continue full_in_path = os.path.join(dirpath, base_name) + + reference_resolver.current_doc_full_name = full_in_path + suffix = os.path.relpath(path=full_in_path, start=src_dir) full_out_path = os.path.join(output_dir, suffix) if not base_name.endswith('.md'): @@ -415,6 +420,8 @@ class DocGenerator(object): """Main entry point for generating docs.""" def __init__(self): + if sys.version_info >= (3, 0): + sys.exit('Doc generation is not supported from python3.') self.argument_parser = argparse.ArgumentParser() self._py_modules = None self._private_map = _get_default_private_map() @@ -442,7 +449,7 @@ class DocGenerator(object): '--base_dir', type=str, default=default_base_dir, - help='Base directory to to strip from file names referenced in docs.') + help='Base directory to strip from file names referenced in docs.') def parse_known_args(self): flags, _ = self.argument_parser.parse_known_args() @@ -505,7 +512,6 @@ class DocGenerator(object): write_docs(output_dir, parser_config, yaml_toc=self.yaml_toc) _other_docs(flags.src_dir, flags.output_dir, reference_resolver) - if parser.all_errors: - print('Errors during processing:\n ' + '\n '.join(parser.all_errors)) - return 1 - return 0 + parser_config.reference_resolver.log_errors() + + return parser_config.reference_resolver.num_errors() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index 6e5deb6a36ed7d7d8b51f28e7ed3d9a680fce13b..1ceaf31f1c3b83e2c2cb3c0d2022ce98781aed4b 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -21,9 +21,6 @@ from __future__ import print_function import os import sys -import tensorflow as tf - -from tensorflow.python import debug as tf_debug from tensorflow.python.platform import googletest from tensorflow.tools.docs import generate_lib from tensorflow.tools.docs import parser @@ -54,23 +51,10 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def 
test_extraction(self): - py_modules = [('tf', tf), ('tfdbg', tf_debug)] - - try: - generate_lib.extract(py_modules, - generate_lib._get_default_private_map(), - generate_lib._get_default_do_not_descend_map()) - except RuntimeError: - print('*****************************************************************') - print('If this test fails, you have most likely introduced an unsealed') - print('module. Make sure to use remove_undocumented or similar utilities') - print('to avoid leaking symbols. See below for more information on the') - print('failure.') - print('*****************************************************************') - raise - def test_write(self): + if sys.version_info >= (3, 0): + self.skipTest('Warning: Doc generation is not supported from python3.') + module = sys.modules[__name__] index = { diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 7ae1d2abd9af813d29e527f447b6ce21c8e72b82..563e5be814ce227279b4f55e6050ff902de54487 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -24,6 +24,7 @@ import functools import json import os import re +import sys import codegen import six @@ -35,13 +36,36 @@ from tensorflow.python.util import tf_inspect # A regular expression capturing a python indentifier. IDENTIFIER_RE = '[a-zA-Z_][a-zA-Z0-9_]*' -# Log of all reported errors -all_errors = [] +class _Errors(object): + """A collection of errors.""" -def log_error(s): - all_errors.append(s) - print('ERROR:', s) + def __init__(self): + self._errors = [] + + def log_all(self): + """Log all the collected errors to the standard error.""" + template = 'ERROR:\n output file name: %s\n %s\n\n' + + for full_name, message in self._errors: + print(template % (full_name, message), file=sys.stderr) + + def append(self, full_name, message): + """Add an error to the collection. + + Args: + full_name: The path to the file in which the error occurred. + message: The message to display with the error. 
+ """ + self._errors.append((full_name, message)) + + def __len__(self): + return len(self._errors) + + def __eq__(self, other): + if not isinstance(other, _Errors): + return False + return self._errors == other._errors # pylint: disable=protected-access def documentation_path(full_name): @@ -107,6 +131,18 @@ class ReferenceResolver(object): self._all_names = set(is_class.keys()) self._py_module_names = py_module_names + self.current_doc_full_name = None + self._errors = _Errors() + + def add_error(self, message): + self._errors.append(self.current_doc_full_name, message) + + def log_errors(self): + self._errors.log_all() + + def num_errors(self): + return len(self._errors) + @classmethod def from_visitor(cls, visitor, doc_index, **kwargs): """A factory function for building a ReferenceResolver from a visitor. @@ -153,7 +189,8 @@ class ReferenceResolver(object): for key, value in self.__dict__.items(): # Drop these two fields. `_doc_index` is not serializable. `_all_names` is # generated by the constructor. - if key in ('_doc_index', '_all_names'): + if key in ('_doc_index', '_all_names', + '_errors', 'current_doc_full_name'): continue # Strip off any leading underscores on field names as these are not @@ -186,10 +223,10 @@ class ReferenceResolver(object): Returns: `string`, with "@{symbol}" references replaced by Markdown links. """ - return re.sub(SYMBOL_REFERENCE_RE, - lambda match: self._one_ref(match.group(1), # pylint: disable=g-long-lambda - relative_path_to_root), - string) + def one_ref(match): + return self._one_ref(match, relative_path_to_root) + + return re.sub(SYMBOL_REFERENCE_RE, one_ref, string) def python_link(self, link_text, ref_full_name, relative_path_to_root, code_ref=True): @@ -250,9 +287,8 @@ class ReferenceResolver(object): # Check whether this link exists if master_name not in self._all_names: - # TODO(josh11b): Make error reporting more uniform. - print('ERROR: Cannot make link to %s (original: %s): Not in index.' 
% - (master_name, ref_full_name)) + message = 'Cannot make link to "%s": Not in index.' % master_name + self.add_error(message) return 'BROKEN_LINK' # If this is a member of a class, link to the class page with an anchor. @@ -270,8 +306,10 @@ class ReferenceResolver(object): return os.path.join(relative_path_to_root, ref_path) - def _one_ref(self, string, relative_path_to_root): + def _one_ref(self, match, relative_path_to_root): """Return a link for a single "@{symbol}" reference.""" + string = match.group(1) + # Look for link text after $. dollar = string.rfind('$') if dollar > 0: # Ignore $ in first character @@ -303,8 +341,8 @@ class ReferenceResolver(object): code_ref=not manual_link_text) # Error! - log_error('Did not understand "@{%s}"' % string) - return 'ERROR:%s' % string + self.add_error('Did not understand "%s"' % match.group(0)) + return 'BROKEN_LINK' def _doc_link(self, string, link_text, manual_link_text, relative_path_to_root): @@ -330,7 +368,7 @@ class ReferenceResolver(object): def _doc_missing(self, string, unused_hash_tag, link_text, unused_manual_link_text, unused_relative_path_to_root): """Generate an error for unrecognized @{$...} references.""" - log_error('Handle doc reference "@{$%s}"' % string) + self.add_error('Unknown Document "%s"' % string) return link_text def _cc_link(self, string, link_text, unused_manual_link_text, @@ -348,7 +386,7 @@ class ReferenceResolver(object): elif string == 'tensorflow::ops::Const': ret = 'namespace/tensorflow/ops.md#const' else: - log_error('Handle C++ reference "@{%s}"' % string) + self.add_error('C++ reference not understood: "%s"' % string) return 'TODO_C++:%s' % string # relative_path_to_root gets you to api_docs/python, we go from there # to api_docs/cc, and then add ret. 
@@ -469,7 +507,7 @@ def _parse_function_details(docstring): pairs = list(_gen_pairs(parts[1:])) function_details = [] - item_re = re.compile(r'^ (\w+):', re.MULTILINE) + item_re = re.compile(r'^ (\*?\*?\w+):', re.MULTILINE) for keyword, content in pairs: content = item_re.split(content) diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 3e02160130f1959484472ecc77e8b2e883294a1e..862f0acfa90fbc8ea7f5054b745c684783f1ff5a 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -491,13 +491,13 @@ Returns: class TestParseFunctionDetails(googletest.TestCase): - def testParseFunctionDetails(self): + def test_parse_function_details(self): docstring, function_details = parser._parse_function_details(RELU_DOC) self.assertEqual(len(function_details), 2) args = function_details[0] self.assertEqual(args.keyword, 'Args') - self.assertEmpty(args.header) + self.assertEqual(len(args.header), 0) self.assertEqual(len(args.items), 2) self.assertEqual(args.items[0][0], 'features') self.assertEqual(args.items[1][0], 'name') @@ -515,5 +515,60 @@ class TestParseFunctionDetails(googletest.TestCase): docstring + ''.join(str(detail) for detail in function_details)) +class TestGenerateSignature(googletest.TestCase): + + def test_known_object(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + known_object = object() + reverse_index = {id(known_object): 'location.of.object.in.api'} + + def example_fun(arg=known_object): # pylint: disable=unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index) + self.assertEqual(sig, ['arg=location.of.object.in.api']) + + def test_literals(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + def example_fun(a=5, b=5.0, c=None, d=True, e='hello', f=(1, (2, 3))): # pylint: disable=g-bad-name, unused-argument + pass + + 
sig = parser._generate_signature(example_fun, reverse_index={}) + self.assertEqual( + sig, ['a=5', 'b=5.0', 'c=None', 'd=True', "e='hello'", 'f=(1, (2, 3))']) + + def test_dotted_name(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + # pylint: disable=g-bad-name + class a(object): + + class b(object): + + class c(object): + + class d(object): + + def __init__(self, *args): + pass + # pylint: enable=g-bad-name + + e = {'f': 1} + + def example_fun(arg1=a.b.c.d, arg2=a.b.c.d(1, 2), arg3=e['f']): # pylint: disable=unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index={}) + self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"]) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index fa2cf15cb16ec3396089b5f52ce8718fd05f94a0..cad0567b9e9acb586203cf105b40df9c0094bc61 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -26,14 +26,12 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", ], ) @@ -44,11 +42,46 @@ tf_cc_test( deps = [ ":transform_utils", "//tensorflow/cc:cc_ops", - "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "file_utils", + srcs = [ + "file_utils.cc", + ], + 
hdrs = [ + "file_utils.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +tf_cc_test( + name = "file_utils_test", + size = "small", + srcs = ["file_utils_test.cc"], + deps = [ + ":file_utils", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -60,6 +93,7 @@ cc_library( srcs = [ "add_default_attributes.cc", "backports.cc", + "fake_quantize_training.cc", "fold_batch_norms.cc", "fold_constants_lib.cc", "fold_old_batch_norms.cc", @@ -109,6 +143,7 @@ tf_cc_test( srcs = [ "add_default_attributes_test.cc", "backports_test.cc", + "fake_quantize_training_test.cc", "fold_batch_norms_test.cc", "fold_constants_test.cc", "fold_old_batch_norms_test.cc", @@ -152,6 +187,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":file_utils", ":transform_utils", ":transforms_lib", "//tensorflow/core:framework_internal", @@ -213,6 +249,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":file_utils", ":transform_utils", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -238,9 +275,9 @@ cc_binary( linkstatic = 1, visibility = ["//visibility:public"], deps = [ + ":file_utils", ":transform_utils", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -250,7 +287,12 @@ py_library( name = "transform_graph_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:pywrap_tensorflow"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:errors", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:util", + ], ) tf_py_test( 
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index b4274e67df315aaa094413ff2576e7c58bd610ff..66e0ba60ebcb994eb20910b8db4bb96cbcf9e319 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -578,10 +578,14 @@ eight-bit form. ### quantize_weights -Args: None \ +Args: + +* minimum_size: Tensors with fewer elements than this won't be quantized +(defaults to 1024) + Prerequisites: None -Converts any large (more than 15 element) float Const op into an eight-bit +Converts any large (more than minimum_size) float Const op into an eight-bit equivalent, followed by a float conversion op so that the result is usable by subsequent nodes. This is mostly useful for [shrinking file sizes](#shrinking-file-size), but also helps with the more advanced @@ -760,7 +764,7 @@ heart, all of the transforms take in a valid GraphDef, make some changes, and output a new GraphDef. Each GraphDef is just a list of NodeDefs, each defining one node in the graph and its connections. You can find more information on the format at [this guide to TensorFlow model -files](https://www.tensorflow.org/versions/master/how_tos/tool_developers/index.html), +files](https://www.tensorflow.org/versions/master/extend/tool_developers/index.html), but for a simple example take a look at [tensorflow/tools/graph_transforms/rename_op.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/rename_op.cc), which implements the [rename_op](#rename_op) transform: diff --git a/tensorflow/tools/graph_transforms/compare_graphs.cc b/tensorflow/tools/graph_transforms/compare_graphs.cc index 8fce16337f7a875835c6f5e5aeaf19a6627a3a13..28a80a885f86fed1f0f30d0ecdc87c9dbb7ba27c 100644 --- a/tensorflow/tools/graph_transforms/compare_graphs.cc +++ b/tensorflow/tools/graph_transforms/compare_graphs.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc new file mode 100644 index 0000000000000000000000000000000000000000..321de47db1f1e5b305c91378917dab14f9912748 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fake_quantize_training.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/graph/quantize_training.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Rewrites the GraphDef for quantized training. +// Rewrites the forward pass to include the precision loss with quantization so +// the model can learn to deal with such loss and achieve better accuracy when +// it is quantized later for inference. +// Quantization range information is collected in FakeQuantizeWithMinMaxVars +// ops. +// +// TODO(suharshs): Provide instructions on converting the resulting graph for +// inference. 
+// TODO(suharshs): Implement this using the GTT rather than calling the old +// prototype function. +Status FakeQuantizeTraining(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + // TODO(suharshs): Make num_bits a parameter. + const int32 num_bits = 8; + // TODO(suharshs): Make quantization op a parameter? + const string quant_op_type = "FakeQuantWithMinMaxVars"; + + return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type, + output_graph_def); +} + +REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ea7f512c6760c2e7d7b5870f17df9361e2488f6 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status FakeQuantizeTraining(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class FakeQuantizeTrainingTest : public ::testing::Test {}; + +// For now, since the fake_quantize_training transform just calls the +// quantize_training rewrite from tensorflow/core/graph/quantize_training.h, +// we just test that the graph has been changed by the transform. +// TODO(suharshs): Once we implement the fake_quantize_training transform +// using the GTT, write proper tests of the transform here. +TEST_F(FakeQuantizeTrainingTest, TransformOccurred) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor a_data(DT_FLOAT, TensorShape()); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape()); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef result; + TransformFuncContext context; + TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result)); + + // Test that the transformation resulted in a graph with more nodes. 
+ EXPECT_GT(result.node_size(), graph_def.node_size()); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/file_utils.cc b/tensorflow/tools/graph_transforms/file_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..5649c971982bd7a3db2f856f4219c8f6cc1aa811 --- /dev/null +++ b/tensorflow/tools/graph_transforms/file_utils.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/file_utils.h" + +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace graph_transforms { + +Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) { + string file_data; + Status load_file_status = + ReadFileToString(Env::Default(), file_name, &file_data); + if (!load_file_status.ok()) { + errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")"); + return load_file_status; + } + // Try to load in binary format first, and then try ascii if that fails. 
+ Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def); + if (!load_status.ok()) { + if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) { + load_status = Status::OK(); + } else { + errors::AppendToMessage(&load_status, + " (both text and binary parsing failed for file ", + file_name, ")"); + } + } + return load_status; +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/file_utils.h b/tensorflow/tools/graph_transforms/file_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4737e95abcec3694d426e0c3c3a7112c2c5b6bd1 --- /dev/null +++ b/tensorflow/tools/graph_transforms/file_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace graph_transforms { + +// First tries to load the file as a binary protobuf, if that fails tries to +// parse it as a text protobuf, and returns an error if both fail. 
+Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def); + +} // namespace graph_transforms +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ diff --git a/tensorflow/tools/graph_transforms/file_utils_test.cc b/tensorflow/tools/graph_transforms/file_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c898ba0f1683d21e69e5be2fa6a8ab60bb10e31 --- /dev/null +++ b/tensorflow/tools/graph_transforms/file_utils_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/file_utils.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { +namespace graph_transforms { + +class FileUtilsTest : public ::testing::Test { + protected: + void TestLoadTextOrBinaryGraphFile() { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + const int width = 10; + + auto root = tensorflow::Scope::NewRootScope(); + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + const string text_file = + io::JoinPath(testing::TmpDir(), "text_graph.pbtxt"); + TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def)); + + const string binary_file = + io::JoinPath(testing::TmpDir(), "binary_graph.pb"); + TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def)); + + const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb"); + TF_ASSERT_OK( + WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto...")); + + GraphDef text_graph_def; + TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def)); + string text_diff; + EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff)) + << text_diff; + + GraphDef binary_graph_def; + TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def)); + string binary_diff; + EXPECT_TRUE(EqualGraphDef(binary_graph_def, 
graph_def, &binary_diff)) + << binary_diff; + + GraphDef no_graph_def; + EXPECT_FALSE( + LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def) + .ok()); + + GraphDef bogus_graph_def; + EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok()); + } +}; + +TEST_F(FileUtilsTest, TestLoadTextOrBinaryGraphFile) { + TestLoadTextOrBinaryGraphFile(); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 79472ae554998ceae1eef73577c5b289a857690a..f97e4854183d24000717716ade5a9177a11ead5f 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -152,9 +152,19 @@ Status FoldConstants(const GraphDef& input_graph_def, &input_graph, context.input_names, context.output_names, {}, device_attributes, false /* use_function_convention */, &metadata)); bool was_mutated; - TF_RETURN_IF_ERROR(ConstantFold(ConstantFoldingOptions(), nullptr, - Env::Default(), nullptr, &input_graph, - &was_mutated)); + // Exclude specified nodes from constant folding. 
+ ConstantFoldingOptions cf_opts; + if (context.params.count("exclude_op") > 0) { + const auto& excluded_nodes = context.params.at("exclude_op"); + const std::set excluded_nodes_set(excluded_nodes.begin(), + excluded_nodes.end()); + cf_opts.consider = [excluded_nodes_set](const Node* n) { + return excluded_nodes_set.find(n->op_def().name()) == + excluded_nodes_set.end(); + }; + } + TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr, + &input_graph, &was_mutated)); GraphDef folded_graph_def; input_graph.ToGraphDef(&folded_graph_def); GraphDef send_recvs_replaced; diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index 902f92952a6405ad6eed3f61364f6e127bfda8cb..14e2c01c7c2a5032860992d9a4956816cce1bed0 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/image_ops.h" #include "tensorflow/cc/ops/nn_ops.h" @@ -69,11 +71,46 @@ class ConstantFoldingTest : public ::testing::Test { test::FillIota(&placeholder_tensor, 1.0f); TestConstantFolding(graph_def, {{"placeholder_expect_remains", placeholder_tensor}}, - {"output_expect_remains"}); + {}, {"output_expect_remains"}); + } + + void TestOpExclusionAdd() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = + Const(root.WithOpName("a_expect_remains"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = + Const(root.WithOpName("b_expect_remains"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add_expect_remains"), a_const, b_const); + + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + + Output mul = + Mul(root.WithOpName("output_expect_remains"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&placeholder_tensor, 1.0f); + TestConstantFolding(graph_def, + {{"placeholder_expect_remains", placeholder_tensor}}, + {"Add"}, {"output_expect_remains"}); } void TestConstantFolding(const GraphDef& graph_def, std::vector > inputs, + std::vector excluded_ops, const std::vector& outputs) { std::unique_ptr unfolded_session( tensorflow::NewSession(tensorflow::SessionOptions())); @@ -87,6 +124,7 @@ class ConstantFoldingTest : public ::testing::Test { context.input_names.push_back(input.first); } context.output_names = outputs; + context.params["exclude_op"] = 
std::move(excluded_ops); TF_ASSERT_OK( graph_transforms::FoldConstants(graph_def, context, &folded_graph_def)); @@ -203,6 +241,8 @@ class ConstantFoldingTest : public ::testing::Test { TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); } +TEST_F(ConstantFoldingTest, TestOpExclusionAdd) { TestOpExclusionAdd(); } + TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); } TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); } diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index 066727614c8a24329d9d2f45d9dfe946a51b322b..0978c336b49ce8cc72d9fc35af551a7f15ee697f 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -54,7 +54,7 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, GraphDef replaced_graph_def; TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( current_graph_def, // clang-format off - {"BatchNormWithGlobalNormalization", // batch_norm_node + {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node { {"Conv2D", // conv_node { @@ -74,19 +74,33 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, std::vector* new_nodes) { // Find all the nodes we expect in the subgraph. const NodeDef& batch_norm_node = match.node; - CHECK_EQ("BatchNormWithGlobalNormalization", batch_norm_node.op()); + // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ + // by input order and attribute names. + CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" || + batch_norm_node.op() == "FusedBatchNorm"); + const bool is_fused = batch_norm_node.op() == "FusedBatchNorm"; + const int mean_idx = is_fused ? 3 : 1; + const int var_idx = is_fused ? 4 : 2; + const int beta_idx = is_fused ? 2 : 3; + const int gamma_idx = is_fused ? 1 : 4; + const string epsilon_attr = is_fused ? 
"epsilon" : "variance_epsilon"; + // FusedBatchNorm always scales after normalization. + const bool scale_after_normalization = + is_fused || + batch_norm_node.attr().at("scale_after_normalization").b(); + const NodeDef& conv_node = match.inputs[0].node; CHECK_EQ("Conv2D", conv_node.op()); const NodeDef& input_node = match.inputs[0].inputs[0].node; const NodeDef& weights_node = match.inputs[0].inputs[1].node; CHECK_EQ("Const", weights_node.op()); - const NodeDef& mean_node = match.inputs[1].node; + const NodeDef& mean_node = match.inputs[mean_idx].node; CHECK_EQ("Const", mean_node.op()); - const NodeDef& variance_node = match.inputs[2].node; + const NodeDef& variance_node = match.inputs[var_idx].node; CHECK_EQ("Const", variance_node.op()); - const NodeDef& beta_node = match.inputs[3].node; + const NodeDef& beta_node = match.inputs[beta_idx].node; CHECK_EQ("Const", beta_node.op()); - const NodeDef& gamma_node = match.inputs[4].node; + const NodeDef& gamma_node = match.inputs[gamma_idx].node; CHECK_EQ("Const", gamma_node.op()); // We have a set of vectors that we want to combine into a vector of @@ -98,9 +112,7 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, Tensor beta = GetNodeTensorAttr(beta_node, "value"); Tensor gamma = GetNodeTensorAttr(gamma_node, "value"); const float variance_epsilon = - batch_norm_node.attr().at("variance_epsilon").f(); - const bool scale_after_normalization = - batch_norm_node.attr().at("scale_after_normalization").b(); + batch_norm_node.attr().at(epsilon_attr).f(); // Make sure all the inputs really are vectors, with as many entries // as there are columns in the weights. 
@@ -119,16 +131,17 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, scale_values[i] = (1.0f / sqrtf(variance.flat()(i) + variance_epsilon)) * gamma.flat()(i); - offset_values[i] = 0.0f; } } else { for (int i = 0; i < weights_cols; ++i) { scale_values[i] = (1.0f / sqrtf(variance.flat()(i) + variance_epsilon)); - offset_values[i] = (-mean.flat()(i) * scale_values[i]) + - beta.flat()(i); } } + for (int i = 0; i < weights_cols; ++i) { + offset_values[i] = (-mean.flat()(i) * scale_values[i]) + + beta.flat()(i); + } // Multiply the original weights by the scale vector. auto weights_matrix = weights.flat_inner_dims(); diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 1c4958d83c935e4b298b54461e820b15608d7b8e..3be9110b475f97087be18118d2ba0c52d6388c03 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -118,11 +119,92 @@ class FoldOldBatchNormsTest : public ::testing::Test { EXPECT_NE("BatchNormWithGlobalNormalization", node.op()); } } + + void TestFoldFusedBatchNorms() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor mean_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&mean_data, {10.0f, 20.0f}); + Output mean_op = + Const(root.WithOpName("mean_op"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&variance_data, {0.25f, 0.5f}); + Output variance_op = Const(root.WithOpName("variance_op"), + Input::Initializer(variance_data)); + + Tensor beta_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&beta_data, {0.1f, 0.6f}); + Output beta_op = + Const(root.WithOpName("beta_op"), Input::Initializer(beta_data)); + + Tensor gamma_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&gamma_data, {1.0f, 2.0f}); + Output gamma_op = + Const(root.WithOpName("gamma_op"), 
Input::Initializer(gamma_data)); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + NodeDef batch_norm_node; + batch_norm_node.set_op("FusedBatchNorm"); + batch_norm_node.set_name("output"); + AddNodeInput("conv_op", &batch_norm_node); + AddNodeInput("gamma_op", &batch_norm_node); + AddNodeInput("beta_op", &batch_norm_node); + AddNodeInput("mean_op", &batch_norm_node); + AddNodeInput("variance_op", &batch_norm_node); + SetNodeAttr("T", DT_FLOAT, &batch_norm_node); + SetNodeAttr("epsilon", 0.00001f, &batch_norm_node); + SetNodeAttr("is_training", false, &batch_norm_node); + *(original_graph_def.mutable_node()->Add()) = batch_norm_node; + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("FusedBatchNorm", node.op()); + } + } }; TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) { TestFoldOldBatchNorms(); } +TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) { + TestFoldFusedBatchNorms(); +} + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc index da064377ac3f2919e0d0421099d9407a35518e22..2b85e7e83c6f3e2c8d0840f0b9eb0b4992a8b113 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes.cc +++ 
b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -119,6 +119,13 @@ const std::vector& GetQuantizedOpList() { DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"ResizeBilinear", + {"align_corners"}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {1}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, {"Relu6", {}, {{"Tinput", DT_QUINT8}}, diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc index d02655f3f9cb5093a9c542e90aef2f8069e6e1dd..eca263a1ae0dbfad51565b1d3d0d26b066704fc8 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -106,8 +106,8 @@ class QuantizeNodesTest : public ::testing::Test { // Reshape is not included here because it can be added as part of the // quantization process. const std::set quantizable_ops = { - "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", - "Relu", "Relu6", "AvgPool", "MaxPool", "Mul"}; + "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", "Relu", + "Relu6", "ResizeBilinear", "AvgPool", "MaxPool", "Mul"}; for (const NodeDef& node : quantized_graph_def.node()) { EXPECT_EQ(0, quantizable_ops.count(node.op())) << "Found quantizable node " << node.op() << " for node named " @@ -652,6 +652,33 @@ class QuantizeNodesTest : public ::testing::Test { EXPECT_EQ("requantize_op", node_map.at("final_dequantize")->input(0)); } + void TestQuantizeResizeBilinear() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor size_tensor(DT_INT32, TensorShape({2})); + test::FillValues(&size_tensor, {256, 256}); + + Output constant_op = Const(root.WithOpName("size_tensor_op"), + Input::Initializer(size_tensor)); + + Output placeholder_op = + Placeholder(root.WithOpName("placeholder_op"), DT_FLOAT); + + Output resize_bilinear_op = ResizeBilinear( + root.WithOpName("resize_bilinear_op"), placeholder_op, constant_op); + + GraphDef 
float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + Tensor input_tensor(DT_FLOAT, {1, 128, 128, 3}); + test::FillFn(&input_tensor, [](int) { return 100.0f; }); + + TestQuantizedVersusFloatGraph(float_graph_def, + {{"placeholder_op", input_tensor}}, + {"resize_bilinear_op"}); + } + void TestRemoveRedundantQuantizationWithMultipleOutputs() { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) @@ -1446,6 +1473,10 @@ TEST_F(QuantizeNodesTest, TestQuantizeAvgPool) { TestQuantizeAvgPool(); } TEST_F(QuantizeNodesTest, TestQuantizeReshape) { TestQuantizeReshape(); } +TEST_F(QuantizeNodesTest, TestQuantizeResizeBilinear) { + TestQuantizeResizeBilinear(); +} + TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantization) { TestRemoveRedundantQuantization(); } diff --git a/tensorflow/tools/graph_transforms/quantize_weights.cc b/tensorflow/tools/graph_transforms/quantize_weights.cc index 66d800f0da1f49a2026a71927d6910e18e87f2f5..cccae8a992a64b0f49798eda71513a2fe62ad656 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights.cc @@ -35,11 +35,15 @@ namespace graph_transforms { Status QuantizeWeights(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { + int32 minimum_size; + TF_RETURN_IF_ERROR( + context.GetOneInt32Parameter("minimum_size", 1024, &minimum_size)); TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( input_graph_def, {"Const"}, - [](const NodeMatch& match, const std::set& input_nodes, - const std::set& output_nodes, - std::vector* new_nodes) { + [minimum_size](const NodeMatch& match, + const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { const NodeDef& old_const_node = match.node; if (!old_const_node.attr().count("dtype")) { return errors::InvalidArgument("No 'dtype' attribute for Const node ", @@ -58,7 +62,7 @@ Status QuantizeWeights(const GraphDef& 
input_graph_def, const size_t num_elements = old_tensor.NumElements(); // If this isn't a float constant, or it's too small, then reuse the // same node with no changes. - if ((old_dtype != DT_FLOAT) || (num_elements < 16)) { + if ((old_dtype != DT_FLOAT) || (num_elements < minimum_size)) { new_nodes->push_back(old_const_node); return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc index 63c5b5a64d915e99f929e83650ac3d1dd432c6af..e1828831db19e9b449239b08e12e6e78c473552f 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc @@ -70,9 +70,12 @@ class QuantizeWeightsTest : public ::testing::Test { 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}, &original_graph_def); + TransformFuncContext context; + context.output_names = {"output"}; + context.params["minimum_size"] = {"16"}; GraphDef quantized_graph_def; - TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}}, - &quantized_graph_def)); + TF_ASSERT_OK( + QuantizeWeights(original_graph_def, context, &quantized_graph_def)); // Verify the structure of the quantized graph. 
std::map node_lookup; @@ -122,9 +125,12 @@ TEST_F(QuantizeWeightsTest, RangesAlwaysIncludeZero) { 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}, &original_graph_def); + TransformFuncContext context; + context.output_names = {"output"}; + context.params["minimum_size"] = {"16"}; GraphDef quantized_graph_def; - TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}}, - &quantized_graph_def)); + TF_ASSERT_OK( + QuantizeWeights(original_graph_def, context, &quantized_graph_def)); std::map node_lookup; MapNamesToNodes(quantized_graph_def, &node_lookup); diff --git a/tensorflow/tools/graph_transforms/rename_attribute_test.cc b/tensorflow/tools/graph_transforms/rename_attribute_test.cc index a0a33e9fc090acea176333ec840e2e6f438ca998..31619d82ad998a48dde7a3c73fba12a16a0360c2 100644 --- a/tensorflow/tools/graph_transforms/rename_attribute_test.cc +++ b/tensorflow/tools/graph_transforms/rename_attribute_test.cc @@ -43,17 +43,17 @@ class RenameAttributeTest : public ::testing::Test { mul_node1->set_op("Mul"); mul_node1->add_input("add_node2"); mul_node1->add_input("add_node3"); - AddNodeAttr("foo", 23, mul_node1); - AddNodeAttr("bar", "something", mul_node1); + AddNodeAttr("foo", 23, mul_node1); + AddNodeAttr("bar", "something", mul_node1); NodeDef* add_node2 = graph_def.add_node(); add_node2->set_name("add_node2"); add_node2->set_op("Add"); add_node2->add_input("const_node1"); add_node2->add_input("const_node2"); - AddNodeAttr("foo", 46, add_node2); - AddNodeAttr("bob", 23, add_node2); - AddNodeAttr("bar", "something else", add_node2); + AddNodeAttr("foo", 46, add_node2); + AddNodeAttr("bob", 23, add_node2); + AddNodeAttr("bar", "something else", add_node2); NodeDef* add_node3 = graph_def.add_node(); add_node3->set_name("add_node3"); diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 
c441a089ced86e3ad779ae782eeec2e7e59e1e22..937d8c09ff78b0bf8e668bcef978f8e8e4120fdb 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -78,13 +78,14 @@ void CreateConstNode(const Tensor& tensor, const string& name, node_def->set_name(name); SetNodeTensorAttr("value", tensor, node_def); } -} // namespace -Status SparsifyGather(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def) { +Status SparsifyGatherInternal(const GraphDef& input_graph_def, + const TransformFuncContext& context, + const OpTypePattern& pattern, + GraphDef* output_graph_def) { GraphDef current_graph_def = input_graph_def; bool any_match_found = false; + // The subgraphs may have overlapping components, therefore GraphMatcher // doesn't return all subgraphs in one round -- this has to be multi-round // update. @@ -94,17 +95,7 @@ Status SparsifyGather(const GraphDef& input_graph_def, std::vector init_table_node_names; TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( - current_graph_def, // clang-format off - {"Gather", - { - {"Identity", - { - {"Const"} - } - }, - {"*"}, - } - }, // clang-format on + current_graph_def, pattern, [&any_match_found, &init_table_node_names]( const NodeMatch& match, const std::set& input_nodes, const std::set& output_nodes, @@ -143,6 +134,33 @@ Status SparsifyGather(const GraphDef& input_graph_def, // c. a `default_val` arg, valued at 0 // clang-format on const NodeDef& gather_node = match.node; + + // GatherV2 adds an "axis" parameter. sparsify_gather only supports + // axis 0 gathers. + if (gather_node.op() == "GatherV2") { + // Per the OpTypePattern, the 3rd input to Gather must be a Const. 
+ const NodeDef& axis_node = match.inputs[2].node; + + Tensor axis_t; + TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t)); + int64 axis = 0; + if (axis_t.dtype() == DT_INT32) { + axis = axis_t.scalar()(); + } else if (axis_t.dtype() == DT_INT64) { + axis = axis_t.scalar()(); + } else { + return tensorflow::errors::FailedPrecondition( + "Gather axis was not int32 or int64."); + } + + if (axis != 0) { + return tensorflow::errors::FailedPrecondition( + "Transform only applicable to subgraph with GatherV2 over " + "axis 0. Found axis ", + axis, "."); + } + } + const NodeDef& const_node = match.inputs[0].inputs[0].node; DataType data_type; @@ -269,6 +287,45 @@ Status SparsifyGather(const GraphDef& input_graph_def, *output_graph_def = current_graph_def; return Status::OK(); } +} // namespace + +Status SparsifyGather(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + // clang-format off + const OpTypePattern gather_pattern = + {"Gather", + { + {"Identity", + { + {"Const"} + } + }, + {"*"}, + } + }; + const OpTypePattern gather_v2_pattern = + {"GatherV2", + { + {"Identity", + { + {"Const"} + } + }, + {"*"}, + // GatherV2's axis must be constant. + {"Const"}, + } + }; + // clang-format on + + GraphDef temp_output; + TF_RETURN_IF_ERROR(SparsifyGatherInternal(input_graph_def, context, + gather_pattern, &temp_output)); + + return SparsifyGatherInternal(temp_output, context, gather_v2_pattern, + output_graph_def); +} REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather); diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index 8d353d34763a7f362e4e3164b5a93c504bee6fbe..c999212d6931fda940f4fad6c2b199c0e82c37aa 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/public/session.h" @@ -33,19 +34,33 @@ Status SparsifyGather(const GraphDef& input_graph_def, class SparsifyGatherTest : public ::testing::Test { protected: - NodeDef* CreateNode(const string& name, const string& op, + NodeDef* CreateNode(const StringPiece name, const StringPiece op, const std::vector& inputs, GraphDef* graph_def) { NodeDef* node_def = graph_def->add_node(); - node_def->set_name(name); - node_def->set_op(op); + node_def->set_name(name.ToString()); + node_def->set_op(op.ToString()); std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { node_def->add_input(input->name()); }); return node_def; } - void TestSinglePartitionConst() { + void MakeGather(StringPiece name, bool gather_v2, NodeDef* params, + NodeDef* indices, GraphDef* graph_def) { + if (gather_v2) { + NodeDef* axis_node = + CreateNode(strings::StrCat(name, "_axis"), "Const", {}, graph_def); + Tensor axis_t(DT_INT32, TensorShape({})); + axis_t.scalar()() = 0; + SetNodeTensorAttr("value", axis_t, axis_node); + CreateNode(name, "GatherV2", {params, indices, axis_node}, graph_def); + } else { + CreateNode(name, "Gather", {params, indices}, graph_def); + } + } + + void TestSinglePartitionConst(bool gather_v2) { GraphDef graph_def; // Build the graph. @@ -59,7 +74,7 @@ class SparsifyGatherTest : public ::testing::Test { NodeDef* identity_node = CreateNode("const/read", "Identity", {const_node}, &graph_def); - CreateNode("gather", "Gather", {identity_node, input_node}, &graph_def); + MakeGather("gather", gather_v2, identity_node, input_node, &graph_def); CreateNode("group_deps", "NoOp", {}, &graph_def); // Run the op. 
@@ -151,7 +166,7 @@ class SparsifyGatherTest : public ::testing::Test { node_lookup.at("group_deps")->input().end()); } - void TestMultiPartitionConst() { + void TestMultiPartitionConst(bool gather_v2) { // The 'ids' node is served input for two 'Gather's. GraphDef graph_def; @@ -177,8 +192,8 @@ class SparsifyGatherTest : public ::testing::Test { CreateNode("const1/read", "Identity", {const_node1}, &graph_def); NodeDef* identity_node2 = CreateNode("const2/read", "Identity", {const_node2}, &graph_def); - CreateNode("gather1", "Gather", {identity_node1, input_node}, &graph_def); - CreateNode("gather2", "Gather", {identity_node2, input_node}, &graph_def); + MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def); + MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def); // Run the op. GraphDef result; @@ -341,11 +356,13 @@ class SparsifyGatherTest : public ::testing::Test { }; TEST_F(SparsifyGatherTest, TestSinglePartitionConst) { - TestSinglePartitionConst(); + TestSinglePartitionConst(false); + TestSinglePartitionConst(true); } TEST_F(SparsifyGatherTest, TestMultiPartitionConst) { - TestMultiPartitionConst(); + TestMultiPartitionConst(false); + TestMultiPartitionConst(true); } } // namespace graph_transforms diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc index 4eb074998f71e8c1ff51ea64463ff35660bcedca..c0107014e2cf115aeafe78ca879c0cb169cb335b 100644 --- a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index 91670f54d49d057bbb5ff894247c79538877ef5f..6c404c8061e199ca37c2d97eefd4fdb235c6b49a 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -23,13 +23,16 @@ limitations under the License. // bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ // --in_graph=my_graph.pb +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { @@ -81,11 +84,17 @@ void PrintBenchmarkUsage(const std::vector& placeholders, shape = PartialTensorShape(shape_proto); } } - sizes.reserve(shape.dims()); - for (int i = 0; i < shape.dims(); ++i) { - sizes.push_back(shape.dim_size(i)); + string sizes_string; + if (shape.dims() == -1) { + // Unknown shapes can have -1 for dims, so leave these blank. 
+ sizes_string = ""; + } else { + sizes.reserve(shape.dims()); + for (int i = 0; i < shape.dims(); ++i) { + sizes.push_back(shape.dim_size(i)); + } + sizes_string = str_util::Join(sizes, ","); } - string sizes_string = str_util::Join(sizes, ","); input_layer_shapes.push_back(sizes_string); } std::vector output_layers; @@ -116,7 +125,17 @@ Status PrintStructure(const GraphDef& graph) { TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph)); for (const NodeDef& node : sorted_graph.node()) { std::cout << node.name() << " (" << node.op() << "): [" - << str_util::Join(node.input(), ", ") << "]" << std::endl; + << str_util::Join(node.input(), ", ") << "]"; + if (node.op() == "Const") { + Tensor tensor; + if (node.attr().count("value") && + tensor.FromProto(node.attr().at("value").tensor())) { + std::cout << ", value=" << tensor.DebugString(); + } else { + LOG(WARNING) << "Decoding Tensor failed for node" << node.name(); + } + } + std::cout << std::endl; } return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc index e7694104cbded7581529904565a0d13f1a39efba..28387c2b48c06ecffd2afa0705a8dea5bc368460 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.cc +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 0ef517acc5bb6518a54501ea271c21439789da42..bd1e4c90c06f76bbac608940ab792b02e68890d4 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/public/session.h" namespace tensorflow { namespace graph_transforms { @@ -587,28 +586,6 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, return Status::OK(); } -Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) { - string file_data; - Status load_file_status = - ReadFileToString(Env::Default(), file_name, &file_data); - if (!load_file_status.ok()) { - errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")"); - return load_file_status; - } - // Try to load in binary format first, and then try ascii if that fails. 
- Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def); - if (!load_status.ok()) { - if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) { - load_status = Status::OK(); - } else { - errors::AppendToMessage(&load_status, - " (both text and binary parsing failed for file ", - file_name, ")"); - } - } - return load_status; -} - int TransformFuncContext::CountParameters(const string& name) const { if (params.count(name)) { return params.at(name).size(); diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 6ed549a9589af2ff287aa199b2cfb113e40bf871..c0fb4924123ca6637ccc18043aab8d9829a298eb 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -20,10 +20,12 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -107,8 +109,8 @@ void FilterGraphDef(const GraphDef& input_graph_def, std::function selector, GraphDef* output_graph_def); -// Creates a copy of the input graph, with all occurrences of the attributes with -// the names in the argument removed from the node defs. +// Creates a copy of the input graph, with all occurrences of the attributes +// with the names in the argument removed from the node defs. 
void RemoveAttributes(const GraphDef& input_graph_def, const std::vector& attributes, GraphDef* output_graph_def); @@ -131,10 +133,6 @@ Status IsGraphValid(const GraphDef& graph_def); Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, DataTypeVector* outputs); -// First tries to load the file as a text protobuf, if that fails tries to parse -// it as a binary protobuf, and returns an error if both fail. -Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph); - // This is used to spot particular subgraphs in a larger model. To use it, // create a pattern like: // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index d068254b35fd7331f79934139586d3f8d7cd0aff..b5bc2d75fd2726ff5d10026039c07cff7ede2797 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -23,8 +23,6 @@ limitations under the License. 
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace graph_transforms { @@ -1066,50 +1064,6 @@ class TransformUtilsTest : public ::testing::Test { TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value)); EXPECT_TRUE(value); } - - void TestLoadTextOrBinaryGraphFile() { - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - const int width = 10; - - auto root = tensorflow::Scope::NewRootScope(); - Tensor a_data(DT_FLOAT, TensorShape({width})); - test::FillIota(&a_data, 1.0f); - Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); - GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - - const string text_file = - io::JoinPath(testing::TmpDir(), "text_graph.pbtxt"); - TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def)); - - const string binary_file = - io::JoinPath(testing::TmpDir(), "binary_graph.pb"); - TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def)); - - const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb"); - TF_ASSERT_OK( - WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto...")); - - GraphDef text_graph_def; - TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def)); - string text_diff; - EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff)) - << text_diff; - - GraphDef binary_graph_def; - TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def)); - string binary_diff; - EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff)) - << binary_diff; - - GraphDef no_graph_def; - EXPECT_FALSE( - LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def) - .ok()); - - GraphDef bogus_graph_def; - EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok()); 
- } }; TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } @@ -1206,9 +1160,5 @@ TEST_F(TransformUtilsTest, TestGetOneBoolParameter) { TestGetOneBoolParameter(); } -TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) { - TestLoadTextOrBinaryGraphFile(); -} - } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 51ba3b7a0be143a0186269678d508f4f0e95c55c..536437df2b6d1b9a16a9b4d1e218ab6bd01a14e2 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -4,6 +4,7 @@ package(default_visibility = ["//visibility:private"]) load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") +load("//third_party/mkl:build_defs.bzl", "if_mkl") genrule( name = "libtensorflow_proto", @@ -87,6 +88,7 @@ genrule( "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", @@ -100,10 +102,13 @@ genrule( "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@png_archive//:LICENSE", - "@protobuf//:LICENSE", + "@protobuf_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", - ], + ] + if_mkl([ + "//third_party/mkl:LICENSE", + "@mkl//:LICENSE", + ]), outs = ["include/tensorflow/c/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", tools = [":concat_licenses.sh"], @@ -117,6 +122,7 @@ genrule( "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", @@ -130,10 +136,13 @@ genrule( "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@png_archive//:LICENSE", - "@protobuf//:LICENSE", + "@protobuf_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", - ], + ] + if_mkl([ + 
"//third_party/mkl:LICENSE", + "@mkl//:LICENSE", + ]), outs = ["include/tensorflow/jni/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", tools = [":concat_licenses.sh"], diff --git a/tensorflow/tools/mlpbtxt/BUILD b/tensorflow/tools/mlpbtxt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..fc63e9a0b73fd92c63cde5d60bdb9b984922f820 --- /dev/null +++ b/tensorflow/tools/mlpbtxt/BUILD @@ -0,0 +1,44 @@ +# Description: +# This package provides binaries that convert between multi-line and standard +# pbtxt (text-serialization of protocol message) files. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "LICENSE", + "placeholder.txt", +]) + +cc_binary( + name = "tomlpbtxt", + srcs = ["tomlpbtxt.cc"], + deps = [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", + ], +) + +cc_binary( + name = "frommlpbtxt", + srcs = ["frommlpbtxt.cc"], + deps = [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tools/mlpbtxt/frommlpbtxt.cc b/tensorflow/tools/mlpbtxt/frommlpbtxt.cc new file mode 100644 index 0000000000000000000000000000000000000000..643924b318d3fec850ebd6c8275a2eab4884a644 --- /dev/null +++ b/tensorflow/tools/mlpbtxt/frommlpbtxt.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int Run(int argc, char** argv) { + string FLAGS_in = ""; + string FLAGS_out = ""; + + std::vector flag_list = { + Flag("in", &FLAGS_in, "Input multi-line proto text (.mlpbtxt) file name"), + Flag("out", &FLAGS_out, "Output proto text (.pbtxt) file name")}; + + // Parse the command-line. + const string usage = Flags::Usage(argv[0], flag_list); + const bool parse_ok = Flags::Parse(&argc, argv, flag_list); + if (argc != 1 || !parse_ok) { + printf("%s", usage.c_str()); + return 2; + } + + port::InitMain(argv[0], &argc, &argv); + + // Read the input file --in. + string in_contents; + Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents); + if (!s.ok()) { + printf("Error reading file %s: %s\n", FLAGS_in.c_str(), + s.ToString().c_str()); + return 1; + } + + // Write the output file --out. 
+ const string out_contents = PBTxtFromMultiline(in_contents); + s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents); + if (!s.ok()) { + printf("Error writing file %s: %s\n", FLAGS_out.c_str(), + s.ToString().c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::Run(argc, argv); } diff --git a/tensorflow/tools/mlpbtxt/tomlpbtxt.cc b/tensorflow/tools/mlpbtxt/tomlpbtxt.cc new file mode 100644 index 0000000000000000000000000000000000000000..469be49ed3c966c671f1f45619d0a8d88fe519f1 --- /dev/null +++ b/tensorflow/tools/mlpbtxt/tomlpbtxt.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int Run(int argc, char** argv) { + string FLAGS_in = ""; + string FLAGS_out = ""; + string FLAGS_fields = "description"; + + std::vector flag_list = { + Flag("in", &FLAGS_in, "Input proto text (.pbtxt) file name"), + Flag("out", &FLAGS_out, + "Output multi-line proto text (.mlpbtxt) file name"), + Flag("fields", &FLAGS_fields, "Comma-separated list of field names")}; + + // Parse the command-line. + const string usage = Flags::Usage(argv[0], flag_list); + const bool parse_ok = Flags::Parse(&argc, argv, flag_list); + if (argc != 1 || !parse_ok) { + printf("%s", usage.c_str()); + return 2; + } + + // Parse the --fields option. + std::vector fields = + str_util::Split(FLAGS_fields, ',', str_util::SkipEmpty()); + if (fields.empty()) { + printf("--fields must be non-empty.\n%s", usage.c_str()); + return 2; + } + + port::InitMain(argv[0], &argc, &argv); + + // Read the input file --in. + string in_contents; + Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents); + if (!s.ok()) { + printf("Error reading file %s: %s\n", FLAGS_in.c_str(), + s.ToString().c_str()); + return 1; + } + + // Write the output file --out. 
+ const string out_contents = PBTxtToMultiline(in_contents, fields); + s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents); + if (!s.ok()) { + printf("Error writing file %s: %s\n", FLAGS_out.c_str(), + s.ToString().c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::Run(argc, argv); } diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 798338d787551769d94afd9f774a23655a640086..4cd42d79c0600466684e87fcb3d8fd79f8f600c9 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -55,6 +55,15 @@ py_test( ], ) +py_binary( + name = "check_load_py_test", + srcs = ["check_load_py_test.py"], + data = [ + "//tensorflow:all_opensource_files", + ], + srcs_version = "PY2AND3", +) + # On Windows, python binary is a zip file of runfiles tree. # Add everything to its data dependency for generating a runfiles tree # for building the pip package on Windows. @@ -73,12 +82,9 @@ py_binary( "//tensorflow/python:util_example_parser_configuration", "//tensorflow/python/debug:debug_pip", "//tensorflow/python/saved_model", + "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python/tools:tools_pip", # These targets don't build on Windows yet. Exclude them for now. - # rules_closure currently doesn't build on Windows due to - # https://github.com/bazelbuild/rules_closure/pull/206 - # Since tensorboard dependes on rules_closure, exclude tensorboard until it's fixed. 
- # "//tensorflow/tensorboard", # "//tensorflow/contrib/ndlstm", # "//tensorflow/contrib/slim", # "//tensorflow/contrib/slim/python/slim/nets:nets_pip", @@ -99,6 +105,7 @@ filegroup( "//third_party/hadoop:LICENSE.txt", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", @@ -112,17 +119,17 @@ filegroup( "@libxsmm_archive//:LICENSE", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", - "@nanopb_git//:LICENSE.txt", - "@org_html5lib//:LICENSE", - "@org_mozilla_bleach//:LICENSE", - "@org_pocoo_werkzeug//:LICENSE", - "@org_pythonhosted_markdown//:LICENSE.md", + "@grpc//third_party/nanopb:LICENSE.txt", "@png_archive//:LICENSE", - "@protobuf//:LICENSE", + "@protobuf_archive//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", - ] + if_not_windows([ + "@org_python_pypi_backports_weakref//:LICENSE", + ] + if_mkl([ + "//third_party/mkl:LICENSE", + "@mkl//:LICENSE", + ]) + if_not_windows([ "@nccl_archive//:LICENSE.txt", ]) + tf_additional_license_deps(), ) @@ -141,11 +148,14 @@ sh_binary( ":included_headers", ":simple_console", "//tensorflow:tensorflow_py", + "//tensorflow/contrib/boosted_trees:boosted_trees_pip", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", + "//tensorflow/contrib/predictor:predictor_pip", "//tensorflow/contrib/session_bundle:session_bundle_pip", "//tensorflow/contrib/signal:signal_py", "//tensorflow/contrib/slim:slim", @@ -154,14 +164,19 @@ sh_binary( "//tensorflow/contrib/specs:specs", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip", + "//tensorflow/contrib/timeseries:timeseries_pip", + 
"//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_helper_library", + "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/examples/tutorials/mnist:package", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python:util_example_parser_configuration", "//tensorflow/python/debug:debug_pip", "//tensorflow/python/saved_model:saved_model", + "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python/tools:tools_pip", - "//tensorflow/tensorboard", + "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]), ) diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 8b9a6b3de05c6639a85e1d0437cfa2639d142b92..ff7db52cb0e1ea48498e06f8c808373a1bfd2dce 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -17,6 +17,10 @@ set -e +function real_path() { + [[ $1 = /* ]] && echo "$1" || echo "$PWD/${1#./}" +} + function cp_external() { local src_dir=$1 local dest_dir=$2 @@ -41,7 +45,7 @@ function main() { exit 1 fi - DEST=$1 + DEST=$(real_path $1) TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) GPU_FLAG="" @@ -79,23 +83,6 @@ function main() { bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles \ "${TMPDIR}/external" RUNFILES=bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow - elif [ ! -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow ]; then - # Really old (0.2.1-) runfiles, without workspace name. 
- cp -R \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/tensorflow \ - "${TMPDIR}" - mkdir "${TMPDIR}/external" - cp_external \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/external \ - "${TMPDIR}/external" - RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles - # Copy MKL libs over so they can be loaded at runtime - if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then - mkdir "${TMPDIR}/_solib_k8" - cp -R \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \ - "${TMPDIR}/_solib_k8" - fi else if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then # Old-style runfiles structure (--legacy_external_runfiles). @@ -107,11 +94,13 @@ function main() { bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \ "${TMPDIR}/external" # Copy MKL libs over so they can be loaded at runtime - if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then - mkdir "${TMPDIR}/_solib_k8" - cp -R \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \ - "${TMPDIR}/_solib_k8" + so_lib_dir="bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8" + if [ -d ${so_lib_dir} ]; then + mkl_so_dir=$(ls ${so_lib_dir} | grep mkl) + if [ $? -eq 0 ]; then + mkdir "${TMPDIR}/_solib_k8" + cp -R ${so_lib_dir}/${mkl_so_dir} "${TMPDIR}/_solib_k8" + fi fi else # New-style runfiles structure (--nolegacy_external_runfiles). 
@@ -124,11 +113,13 @@ function main() { bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \ "${TMPDIR}/external" # Copy MKL libs over so they can be loaded at runtime - if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl ]; then - mkdir "${TMPDIR}/_solib_k8" - cp -R \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8/_U_S_Sthird_Uparty_Smkl_Cintel_Ubinary_Ublob___Uthird_Uparty_Smkl \ - "${TMPDIR}/_solib_k8" + so_lib_dir="bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/_solib_k8" + if [ -d ${so_lib_dir} ]; then + mkl_so_dir=$(ls ${so_lib_dir} | grep mkl) + if [ $? -eq 0 ]; then + mkdir "${TMPDIR}/_solib_k8" + cp -R ${so_lib_dir}/${mkl_so_dir} "${TMPDIR}/_solib_k8" + fi fi fi RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow @@ -139,7 +130,7 @@ function main() { mkdir -p ${TMPDIR}/google mkdir -p ${TMPDIR}/third_party pushd ${RUNFILES%org_tensorflow} - for header in $(find protobuf -name \*.h); do + for header in $(find protobuf_archive -name \*.h); do mkdir -p "${TMPDIR}/google/$(dirname ${header})" cp "$header" "${TMPDIR}/google/$(dirname ${header})/" done diff --git a/tensorflow/tools/pip_package/check_load_py_test.py b/tensorflow/tools/pip_package/check_load_py_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7a132a8de3ea596a2fbd8e661308f1603778666b --- /dev/null +++ b/tensorflow/tools/pip_package/check_load_py_test.py @@ -0,0 +1,89 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests to check that py_test are properly loaded in BUILD files.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import subprocess + + +def check_output_despite_error(args): + """Get output of args from command line, even if there are errors. + + Args: + args: a list of command line args. + + Returns: + output as string. + """ + try: + output = subprocess.check_output(args, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output + return output.strip() + + +def main(): + # Get all py_test target, note bazel query result will also include + # cuda_py_test etc. + try: + targets = subprocess.check_output( + 'bazel query "kind(py_test, //tensorflow/contrib/... + ' + '//tensorflow/python/... - ' + '//tensorflow/contrib/tensorboard/...)"', + shell=True).strip() + except subprocess.CalledProcessError as e: + targets = e.output + + # Only keep py_test targets, and filter out targets with 'no_pip' tag. + valid_targets = [] + for target in targets.split('\n'): + kind = check_output_despite_error(['buildozer', 'print kind', target]) + if kind == 'py_test': + tags = check_output_despite_error(['buildozer', 'print tags', target]) + if 'no_pip' not in tags: + valid_targets.append(target) + + # Get all BUILD files for all valid targets. 
+ build_files = set() + for target in valid_targets: + build_files.add(os.path.join(target[2:].split(':')[0], 'BUILD')) + + # Check if BUILD files load py_test. + files_missing_load = [] + for build_file in build_files: + updated_build_file = subprocess.check_output( + 'buildozer -stdout "new_load //tensorflow:tensorflow.bzl py_test" ' + + build_file, + shell=True) + with open(build_file, 'r') as f: + if f.read() != updated_build_file: + files_missing_load.append(build_file) + + if files_missing_load: + raise RuntimeError('The following files are missing %s:\n %s' % ( + 'load("//tensorflow:tensorflow.bzl", "py_test").\nThis load statement' + ' is needed because otherwise pip tests will try to use their ' + 'dependencies, which are not visible to them.', + '\n'.join(files_missing_load))) + else: + print('TEST PASSED.') + + +if __name__ == '__main__': + main() diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index dec08157c2c3edf9e632227fb54a50abf3b1b49d..83909d83ae4c45404419745ef7982649e7f416f5 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -60,6 +60,12 @@ BLACKLIST = [ "//tensorflow/contrib/framework:checkpoint_ops_testdata", "//tensorflow/contrib/bayesflow:reinforce_simple_example", "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long + "//tensorflow/contrib/timeseries/examples:predict", + "//tensorflow/contrib/timeseries/examples:multivariate", + "//tensorflow/contrib/timeseries/examples:known_anomaly", + "//tensorflow/contrib/timeseries/examples:data/period_trend.csv", # pylint:disable=line-too-long + "//tensorflow/contrib/timeseries/python/timeseries:test_utils", + "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", # pylint:disable=line-too-long ] diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py 
index 3bbc9000429a0c871470dc575e5ff8718a622378..0b0ee4c857224d239776e7a77b051221fc0d0a3b 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,17 +29,13 @@ from setuptools.dist import Distribution # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.2.0-rc2' +_VERSION = '1.3.0-rc0' REQUIRED_PACKAGES = [ 'numpy >= 1.11.0', 'six >= 1.10.0', - 'protobuf >= 3.2.0', - 'werkzeug >= 0.11.10', - 'html5lib == 0.9999999', # identical to 1.0b8 - 'markdown == 2.2.0', - 'bleach == 1.5.0', - 'backports.weakref == 1.0rc1', + 'protobuf >= 3.3.0', + 'tensorflow-tensorboard', ] project_name = 'tensorflow' @@ -57,9 +53,12 @@ else: # mock comes with unittest.mock for python3, need to install for python2 REQUIRED_PACKAGES.append('mock >= 2.0.0') +# weakref.finalize was introduced in Python 3.4 +if sys.version_info < (3, 4): + REQUIRED_PACKAGES.append('backports.weakref >= 1.0rc1') + # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ - 'tensorboard = tensorflow.tensorboard.tensorboard:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', ] # pylint: enable=line-too-long @@ -114,7 +113,7 @@ class InstallHeaders(Command): install_dir = os.path.join(self.install_dir, os.path.dirname(header)) # Get rid of some extra intervening directories so we can have fewer # directories for -I - install_dir = re.sub('/google/protobuf/src', '', install_dir) + install_dir = re.sub('/google/protobuf_archive/src', '', install_dir) # Copy eigen code into tensorflow/include. 
# A symlink would do, but the wheel file that gets created ignores @@ -165,7 +164,7 @@ else: headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*.h', 'tensorflow/stream_executor')) + - list(find_files('*.h', 'google/protobuf/src')) + + list(find_files('*.h', 'google/protobuf_archive/src')) + list(find_files('*', 'third_party/eigen3')) + list(find_files('*', 'external/eigen_archive'))) @@ -191,8 +190,6 @@ setup( package_data={ 'tensorflow': [ EXTENSION_NAME, - 'tensorboard/components/index.html', - 'tensorboard/TAG', ] + matches, }, zip_safe=False, diff --git a/tensorflow/tools/quantization/BUILD b/tensorflow/tools/quantization/BUILD index cb41185219c56f9a0d834a2e4b5b71c57b46810a..e99ad06a06294c4d037b76ea9450e51bd795e79d 100644 --- a/tensorflow/tools/quantization/BUILD +++ b/tensorflow/tools/quantization/BUILD @@ -13,7 +13,20 @@ py_library( name = "quantize_graph_lib", srcs = ["quantize_graph.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:graph_util", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//third_party/py/numpy", + ], ) py_binary( @@ -27,18 +40,17 @@ py_binary( "//tensorflow/python:client", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:graph_util", "//tensorflow/python:platform", + "//tensorflow/python:tensor_util", "//third_party/py/numpy", - "@six_archive//:six", ], ) py_test( name = "quantize_graph_test", size = "small", - srcs = [ - "quantize_graph_test.py", - ], + srcs = ["quantize_graph_test.py"], srcs_version = "PY2AND3", tags = ["nomsan"], # http://b/32242946 deps = [ @@ -48,6 +60,7 @@ py_test( 
"//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:graph_util", "//tensorflow/python:platform", "//third_party/py/numpy", ], @@ -55,12 +68,13 @@ py_test( py_binary( name = "graph_to_dot", - srcs = [ - "graph_to_dot.py", - ], + srcs = ["graph_to_dot.py"], main = "graph_to_dot.py", srcs_version = "PY2AND3", - deps = ["//tensorflow/python:platform"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + ], ) filegroup( diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 9367bcd4a3457d7387ee8dc17a4d19043fa8c9a2..28d651e9106b29058824c06b160df2b9b5781757 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -22,6 +22,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/core:protos_all_py", "//tensorflow/python:client", "//tensorflow/python:errors", "//tensorflow/python:platform", @@ -46,6 +47,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":system_info_lib", + "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", ], ) @@ -54,8 +56,10 @@ py_binary( name = "run_and_gather_logs", srcs = ["run_and_gather_logs.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":run_and_gather_logs_lib", + "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", ], diff --git a/tensorflow/tools/test/performance.bzl b/tensorflow/tools/test/performance.bzl index 2956c6dde74ff38a7f000d6b6b595beaa397fa76..64fff844a70d439306c9bcf7f21d5a6047fa428a 100644 --- a/tensorflow/tools/test/performance.bzl +++ b/tensorflow/tools/test/performance.bzl @@ -28,7 +28,7 @@ def tf_cc_logged_benchmark( name = name, tags = all_tags, size = "large", - srcs = ["//tensorflow/tools/test:run_and_gather_logs.py"], + srcs = ["//tensorflow/tools/test:run_and_gather_logs"], args = [ "--name=//%s:%s" % (PACKAGE_NAME, name), "--test_name=" 
+ target, diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py index 570e09f1659526198a20db8cb87971a51f353d2b..c798dd5de7532d87387da598a1e7332370e41bed 100644 --- a/tensorflow/tools/test/run_and_gather_logs_lib.py +++ b/tensorflow/tools/test/run_and_gather_logs_lib.py @@ -135,7 +135,7 @@ def run_and_gather_logs(name, test_name, test_args, gpu_config = gpu_info_lib.gather_gpu_devices() if gpu_config: gpu_name = gpu_config[0].model - gpu_short_name_match = re.search(r"Tesla [KP][4,8]0", gpu_name) + gpu_short_name_match = re.search(r"Tesla (K40|K80|P100)", gpu_name) if gpu_short_name_match: gpu_short_name = gpu_short_name_match.group(0) test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_") diff --git a/tensorflow/tools/test/upload_test_benchmarks.py b/tensorflow/tools/test/upload_test_benchmarks.py index 829333e05629946fc5627d37301883d70572b1be..77cc9f75f7725438918f681833d58e9ecb4a2f70 100644 --- a/tensorflow/tools/test/upload_test_benchmarks.py +++ b/tensorflow/tools/test/upload_test_benchmarks.py @@ -162,7 +162,7 @@ def upload_benchmark_data(client, data): t_val.update({ "test": test_name, "start": start_time, - "info": unicode(test_result) + "info": unicode(data) }) batch.append(t_val) diff --git a/tensorflow/tools/test/upload_test_benchmarks_index.yaml b/tensorflow/tools/test/upload_test_benchmarks_index.yaml index 8cd33a1da60cad1c1a0e21998b4eefc81babfd8e..ec7f76f6663b3e586b4b63e92eb576740cd445f9 100644 --- a/tensorflow/tools/test/upload_test_benchmarks_index.yaml +++ b/tensorflow/tools/test/upload_test_benchmarks_index.yaml @@ -27,7 +27,7 @@ indexes: properties: - name: test - name: start - direction: asc + direction: desc # Index to access a specific (test, entry, start) Entity, and also to be able to # fetch a range of (start, timing) graph values for a given (test, entry) pair diff --git a/tensorflow/tools/tfprof/README.md b/tensorflow/tools/tfprof/README.md deleted file mode 100644 index 
54f3cd62f283de853bb1b14e61c96c81f77702b5..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/README.md +++ /dev/null @@ -1,122 +0,0 @@ -# tfprof: TensorFlow Profiler and Beyond - -### Features - -* Profile model architectures - * parameters, tensor shapes, float operations, device placement, etc. -* Profile model performance - * execution time, memory consumption - * Profile multiple steps. -* Auto detect and advise. (Experimental) - -### Interfaces - -* Python API -* Command Line -* Visualization -* C++ API (Not public, contact us if needed.) - -### Views and Options - -tfprof provides 4 different views to organize the statistics. - - * code view: operations are grouped by Python codes that generate them. - * op view: operations are grouped by operation type (E.g. MatMul, Conv2D). - * scope view: operations are organized based on name scope hierarchies. - * graph view: operations are organized based on input/output. - -tfprof provides options to help user select, filter and order statistics. -See [Options](g3doc/options.md) for detail instructions. - -``` --max_depth 10 --min_bytes 0 --min_micros 0 --min_params 0 --min_float_ops 0 --min_occurrence 0 --step -1 --order_by name --account_type_regexes .* --start_name_regexes .* --trim_name_regexes --show_name_regexes .* --hide_name_regexes --account_displayed_op_only false --select params --output stdout: -``` - -### Tutorials - -* [Python API](g3doc/python_api.md) -* [Command Line Interface](g3doc/command_line.md) -* [Profile Time](g3doc/profile_time.md) -* [Profile Memory](g3doc/profile_memory.md) -* [Profile Model Architecture](g3doc/profile_model_architecture.md) -* [Auto Detect and Advise](g3doc/advise.md) -* [Options](g3doc/options.md) - -## Demo - -### Attribute TensorFlow graph running time to your Python codes. 
-```shell -tfprof> code -max_depth 1000 -show_name_regexes .*model_analyzer.*py.* -select micros -account_type_regexes .* -order_by micros -_TFProfRoot (0us/22.44ms) - model_analyzer_test.py:149:run_filename_as_m...:none (0us/22.44ms) - model_analyzer_test.py:33:_run_code_in_main:none (0us/22.44ms) - model_analyzer_test.py:208::test.main() (0us/22.44ms) - model_analyzer_test.py:132:testComplexCodeView:x = lib.BuildFull... (0us/22.44ms) - model_analyzer_testlib.py:63:BuildFullModel:return sgd_op.min... (0us/21.83ms) - model_analyzer_testlib.py:58:BuildFullModel:cell, array_ops.c... (0us/333us) - model_analyzer_testlib.py:54:BuildFullModel:seq.append(array_... (0us/254us) - model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0us/134us) - model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0us/40us) - ... - model_analyzer_testlib.py:61:BuildFullModel:loss = nn_ops.l2_... (0us/28us) - model_analyzer_testlib.py:60:BuildFullModel:target = array_op... (0us/0us) - model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us) -``` - -### Show your model variables and the number of parameters. -``` -tfprof> scope -account_type_regexes VariableV2 -max_depth 4 -select params -_TFProfRoot (--/930.58k params) - global_step (1/1 params) - init/init_conv/DW (3x3x3x16, 432/864 params) - pool_logit/DW (64x10, 640/1.28k params) - pool_logit/DW/Momentum (64x10, 640/640 params) - pool_logit/biases (10, 10/20 params) - pool_logit/biases/Momentum (10, 10/10 params) - unit_last/final_bn/beta (64, 64/128 params) - unit_last/final_bn/gamma (64, 64/128 params) - unit_last/final_bn/moving_mean (64, 64/64 params) - unit_last/final_bn/moving_variance (64, 64/64 params) -``` - -### Show the most expensive operation types. 
-``` -tfprof> op -select micros,bytes,occurrence -order_by micros -SoftmaxCrossEntropyWithLogits 36.58MB (100.00%, 0.05%), 1.37sec (100.00%, 23.56%), 30 -MatMul 2720.57MB (99.95%, 3.66%), 988.90ms (76.44%, 17.05%), 3450 -ConcatV2 741.37MB (96.29%, 1.00%), 421.44ms (59.38%, 7.27%), 6098 -Mul 3957.24MB (95.29%, 5.33%), 418.90ms (52.12%, 7.22%), 9427 -Add 740.05MB (89.96%, 1.00%), 335.26ms (44.89%, 5.78%), 2180 -Sub 32.46MB (88.97%, 0.04%), 216.44ms (39.11%, 3.73%), 4372 -AddN 733.21MB (88.92%, 0.99%), 208.46ms (35.38%, 3.59%), 5481 -Slice 708.07MB (87.94%, 0.95%), 205.27ms (31.78%, 3.54%), 7277 -Fill 954.27MB (86.98%, 1.28%), 154.50ms (28.24%, 2.66%), 9686 -Select 312.33MB (85.70%, 0.42%), 123.04ms (25.58%, 2.12%), 5746 -Sigmoid 152.57MB (85.28%, 0.21%), 96.66ms (23.46%, 1.67%), 2970 -``` - -### Visualize time and memory. - -[CodeTimeline](g3doc/graph_timeline.png) - - -### Teams - -* Xin Pan (xpan@google.com, github: panyx0718) -* Jon Shlens -* Yao Zhang diff --git a/tensorflow/tools/tfprof/g3doc/advise.md b/tensorflow/tools/tfprof/g3doc/advise.md deleted file mode 100644 index 3bce6270ff8368fb57d183c6f4c6a88f5dd5bc07..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/g3doc/advise.md +++ /dev/null @@ -1,44 +0,0 @@ -## Auto Detect and Advise - -tfprof analyzes profiles and generates advises for common issues. - -### Run Advise. -```python -# First create a profiler. See profiler tutorials for more details. -profiler = model_analyzer.Profiler(sess.graph) -run_meta = config_pb2.RunMetadata() -_ = sess.run(r1, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE), - run_metadata=run_meta) -profiler.add_step(1, run_meta) - -# Start advise. -profiler.advise() -``` - -### Checker - -There is no magic behind advise mode. tfprof builds the profiles first, then -it runs through a list of `Checkers`, each one responsible for checking one -area with the profile and report issues. A `Checker` is like a plugin. 
- -For example: - -####JobChecker (Not Available OSS) -* Checking RecvTensor RPC latency and bandwidth. -* Checking CPU/Memory utilization of the job. - -####AcceleratorUtilization Checker -* Checks what percentage of time the accelerator spends on computation. - -####Operation Checker -* Check whether the operation runs with optimal options. -* Checks if there is a better implementation to replace the current operation. - -####Contribute Your Checker - -Follow examples of accelerator_utilization_checker.h - - - diff --git a/tensorflow/tools/tfprof/g3doc/options.md b/tensorflow/tools/tfprof/g3doc/options.md deleted file mode 100644 index 78c72bf5eddab24dc6e967adf8ef5c4a82c0b98f..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/g3doc/options.md +++ /dev/null @@ -1,86 +0,0 @@ -##Options - -###Overview - -For all tfprof views, the statistics are processed with the following procedures - -1) An in-memory data structure is used represent the view. - -2) `-account_type_regexes` is used to first select the operations that match - the specified operation types. An operation has its default type - (e.g. MatMul, Conv2D). `tfprof` also considers device as operation type. - User can also define customized operation type. Hence, an operation has - multiple types. Operations with matched - types are selected for display and their statistics are aggregated - by the in-memory data structure. - -3) Various `-xxx_name_regexes`, `-min_xxx`, `-max_depth` etc options are then - applied to further filter based on names and values. - It's no limited operation name. In code view, - it's the code trace. In op view, it's the operation type name. Different - from `-account_type_regexes`, Statistics are used even if a name is not displayed. - For example, in code view, a callee might be hidden, but its statistics is - still aggregated by it's caller. `-account_displayed_op_only`, however, - breaks the rule and only use statistics of displayed names. 
- -4) Finally, the filtered data structure is displayed in a format depending - on the `-output` option. - -####Option Semantics In Different View -options usually have the same semantics in different views. However, some -can vary. For example `-max_depth` in scope view means the depth of -name scope tree. In op view, it means the length of operation list. -In graph view, in means the number of hops in the graph. - - -###Docs - -`-max_depth`: Show ops that are at most this number of hops from starting op in the tree/graph structure. - -`-min_bytes`: Show ops that request at least this number of bytes. - -`-min_micros`: Show ops that spend at least this number of microseconds to run. - -`-min_params`: Show ops that contains at least this number of parameters. - -`-min_float_ops`: Show ops that contain at least this number of float operations. Only available if an op has op.RegisterStatistics() defined and OpLog is provided - -`-min_occurrence`: Show ops that appear at least this number of times. Only available in "op" view. - -`-step`: Show the stats of the this step when multiple steps of RunMetadata were added. By default, show the average of all steps." - -`-order_by`: Order the results by [name|depth|bytes|micros|params|float_ops|occurrence] - -`-account_type_regexes`: Account and display the ops whose types match one of the type regexes specified. tfprof allow user to define extra op types for ops through tensorflow.tfprof.OpLog proto. regexes are comma-sperated. - -`-start_name_regexes`: Show ops starting from the ops that matches the regexes, recursively. regexes are comma-separated. - -`-trim_name_regexes`: Hide ops starting from the ops that matches the regexes, recursively, regexes are comma-seprated. - -`-show_name_regexes`: Show ops that match the regexes. regexes are comma-seprated. - -`-hide_name_regexes`: Hide ops that match the regexes. regexes are comma-seprated. 
- -Notes: For each op, `-account_type_regexes` is first evaluated, only ops with -types matching the specified regexes are accounted and selected for displayed. -`-start/trim/show/hide_name_regexes` are used to further filter ops for display. -`-start_name_regexes` is evaluated first to search the starting ops to display. -Descendants of starting ops are then evaluated against `-show/hide_name_regexes` -to make display decision. If an op matches trim_name_regexes, all its -descendants are hidden. Ops statistics are *accounted even if they are hidden* -as long as they match the `-account_xxx` options. - -`-account_displayed_op_only`: If True, only account the statistics of ops eventually displayed. If False, account all op statistics matching -account_type_regexes recursively. - -`-select`: Comma-separated list of metrics to show: -[bytes|micros|params|float_ops|occurrence|tensor_value|device|op_types|input_shapes]. - -`-output`: Output results as stdout, file or timeline. -The format is ```output_type:key=value,key=value```. -For example: ```-output timeline:outfile=```. - -```shell -timeline: key=outfile, value=. -stdout: none. -file: key=outfile, value=. -``` diff --git a/tensorflow/tools/tfprof/internal/advisor/checker.h b/tensorflow/tools/tfprof/internal/advisor/checker.h deleted file mode 100644 index b8b057be5b1d6410acfb3e8607693e303a6f963c..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/advisor/checker.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_CHECKER_H_ - -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" - -namespace tensorflow { -namespace tfprof { - -static const char* const kLevel[] = { - "NOTE", // Good to know. - "SUGGEST", // Might get better. - "WARN", // Please do it for better. -}; - -class Checker { - public: - virtual ~Checker(){}; - - virtual string name() = 0; - - std::vector Run(const TFStats* stats) { return Check(stats); } - - protected: - // Returns a vector of string, each one being an advice. - virtual std::vector Check(const TFStats* stats) = 0; -}; -} // namespace tfprof -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_CHECKER_H_ diff --git a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h deleted file mode 100644 index 856f51545921283799a87e053c96a19d0ee4387d..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ -#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ - -#include "tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h" -#include "tensorflow/tools/tfprof/internal/advisor/checker.h" -#include "tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h" -#include "tensorflow/tools/tfprof/internal/advisor/operation_checker.h" - -namespace tensorflow { -namespace tfprof { - -// The Advisor runs a list of Checkers, each checks a specific area. -class Advisor { - public: - Advisor(const TFStats* stats) : stats_(stats) {} - - std::map> Advise() { - // Note: Release a checker's memory ASAP. - std::map> reports = RunInternalCheckers(stats_); - // TODO(xpan): Think of a way to turn off/on specific checkers. - AcceleratorUtilizationChecker au_checker; - reports[au_checker.name()] = au_checker.Run(stats_); - OperationChecker op_checker; - reports[op_checker.name()] = op_checker.Run(stats_); - - for (const auto& checker_r : reports) { - fprintf(stdout, "%s reports:\n", checker_r.first.c_str()); - for (const auto& r : checker_r.second) { - fprintf(stdout, "%s\n", r.c_str()); - } - } - fflush(stdout); - return reports; - } - - private: - const TFStats* stats_; -}; - -} // namespace tfprof -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc deleted file mode 100644 index 498477de0a00f828b07a6a955e05722d6a79d433..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" - -#include - -#include "tensorflow/c/checkpoint_reader.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/tools/tfprof/internal/tfprof_constants.h" -#include "tensorflow/tools/tfprof/internal/tfprof_options.h" -#include "tensorflow/tools/tfprof/internal/tfprof_utils.h" -#include "tensorflow/tools/tfprof/tfprof_log.pb.h" -#include "tensorflow/tools/tfprof/tfprof_output.pb.h" - -namespace tensorflow { -namespace tfprof { -class TFProfShowTest : public ::testing::Test { - protected: - TFProfShowTest() { - string graph_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/graph.pbtxt"); - std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); - - std::unique_ptr run_meta_pb( - new tensorflow::RunMetadata()); - string run_meta_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/run_meta"); - TF_CHECK_OK( - ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get())); - - std::unique_ptr op_log_pb(new OpLog()); - string op_log_path = - 
io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/tfprof_log"); - TF_CHECK_OK(ReadBinaryProto(Env::Default(), op_log_path, op_log_pb.get())); - - string ckpt_path = io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/ckpt"); - TF_Status* status = TF_NewStatus(); - std::unique_ptr ckpt_reader( - new checkpoint::CheckpointReader(ckpt_path, status)); - CHECK(TF_GetCode(status) == TF_OK); - TF_DeleteStatus(status); - - tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb), - std::move(op_log_pb), std::move(ckpt_reader))); - } - - std::unique_ptr tf_stats_; -}; - -TEST_F(TFProfShowTest, DumpScopeMode) { - string dump_file = io::JoinPath(testing::TmpDir(), "dump"); - Options opts(5, 0, 0, 0, 0, 0, -1, "name", - {"VariableV2"}, // accout_type_regexes - {".*"}, {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops"}, "file", - {{"outfile", dump_file}}); - tf_stats_->ShowGraphNode("scope", opts); - - string dump_str; - TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); - EXPECT_EQ( - "node name | # parameters | # float_ops | output bytes | execution " - "time\n_TFProfRoot (--/370 params, --/0 flops, --/1.48KB, --/5us)\n " - "conv2d (--/140 params, --/0 flops, --/560B, --/2us)\n conv2d/bias " - "(5, 5/5 params, 0/0 flops, 20B/20B, 1us/1us)\n conv2d/kernel " - "(3x3x3x5, 135/135 params, 0/0 flops, 540B/540B, 1us/1us)\n conv2d_1 " - "(--/230 params, --/0 flops, --/920B, --/3us)\n conv2d_1/bias (5, 5/5 " - "params, 0/0 flops, 20B/20B, 1us/1us)\n conv2d_1/kernel (3x3x5x5, " - "225/225 params, 0/0 flops, 900B/900B, 2us/2us)\n", - dump_str); -} - -TEST_F(TFProfShowTest, DumpOpMode) { - string dump_file = io::JoinPath(testing::TmpDir(), "dump"); - Options opts( - 5, 0, 0, 0, 0, 4, -1, "params", {".*"}, // accout_type_regexes - {".*"}, {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops", "occurrence", "input_shapes"}, - "file", {{"outfile", dump_file}}); - 
tf_stats_->ShowMultiGraphNode("op", opts); - - string dump_str; - TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); - EXPECT_EQ( - "nodename|outputbytes|executiontime|#parameters|#float_ops|opoccurrence|" - "inputshapes\nVariableV21.48KB(100.00%,17.10%),5us(100.00%,5.15%)," - "370params(100.00%,100.00%),0float_ops(100.00%,0.00%),4\n\ninput_type:\t(" - "*4)\texec_time:5us\n\nAssign0B(0.00%,0.00%),0us(94.85%,0.00%),0params(0." - "00%,0.00%),0float_ops(100.00%,0.00%),8\n\ninput_type:0:unknown,\t1:" - "unknown\t(*8)\texec_time:0us\n\nConst1.54KB(58.87%,17.74%),1us(80.41%,1." - "03%),0params(0.00%,0.00%),0float_ops(98.49%,0.00%),24\n\ninput_type:\t(*" - "24)\texec_time:1us\n\n", - StringReplace(dump_str, " ", "")); -} -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc deleted file mode 100644 index a1e500f94929ce4b04c2ea2aabb4b6e13acd2202..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc +++ /dev/null @@ -1,247 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" - -#include - -#include "tensorflow/c/checkpoint_reader.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/tools/tfprof/internal/tfprof_constants.h" -#include "tensorflow/tools/tfprof/internal/tfprof_options.h" -#include "tensorflow/tools/tfprof/internal/tfprof_utils.h" -#include "tensorflow/tools/tfprof/tfprof_log.pb.h" -#include "tensorflow/tools/tfprof/tfprof_output.pb.h" - -namespace tensorflow { -namespace tfprof { -class TFProfStatsTest : public ::testing::Test { - protected: - TFProfStatsTest() { - string graph_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/graph.pbtxt"); - std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); - - std::unique_ptr run_meta_pb( - new tensorflow::RunMetadata()); - string run_meta_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/run_meta"); - TF_CHECK_OK( - ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get())); - - std::unique_ptr op_log_pb(new OpLog()); - string op_log_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/tfprof_log"); - TF_CHECK_OK(ReadBinaryProto(Env::Default(), op_log_path, op_log_pb.get())); - - string ckpt_path = io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/ckpt"); - TF_Status* status = TF_NewStatus(); - std::unique_ptr ckpt_reader( - new checkpoint::CheckpointReader(ckpt_path, status)); - CHECK(TF_GetCode(status) == TF_OK); - TF_DeleteStatus(status); - - tf_stats_.reset(new 
TFStats(std::move(graph_pb), std::move(run_meta_pb), - std::move(op_log_pb), std::move(ckpt_reader))); - } - - std::unique_ptr tf_stats_; -}; - -TEST_F(TFProfStatsTest, CustomOpType) { - Options opts(3, 0, 0, 0, 0, 0, -1, "name", - {kTrainableVarType}, // accout_type_regexes - {".*"}, {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops"}, "", {}); - const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); - - TFGraphNodeProto expected; - CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " - "0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: " - "370\nchildren {\n name: \"conv2d\"\n exec_micros: 0\n " - "requested_bytes: 0\n total_exec_micros: 2\n total_requested_bytes: " - "560\n total_parameters: 140\n children {\n name: \"conv2d/bias\"\n " - " exec_micros: 1\n requested_bytes: 20\n parameters: 5\n " - "total_exec_micros: 1\n total_requested_bytes: 20\n " - "total_parameters: 5\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d/kernel\"\n " - "exec_micros: 1\n requested_bytes: 540\n parameters: 135\n " - "total_exec_micros: 1\n total_requested_bytes: 540\n " - "total_parameters: 135\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " - "requested_bytes: 0\n total_exec_micros: 3\n total_requested_bytes: " - "920\n total_parameters: 230\n children {\n name: " - "\"conv2d_1/bias\"\n exec_micros: 1\n requested_bytes: 20\n " - "parameters: 5\n total_exec_micros: 1\n total_requested_bytes: " - "20\n total_parameters: 5\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d_1/kernel\"\n " - " exec_micros: 2\n requested_bytes: 900\n parameters: 225\n " - 
"total_exec_micros: 2\n total_requested_bytes: 900\n " - "total_parameters: 225\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nfloat_ops: 0\ntotal_float_ops: 0\n", - &expected)); - EXPECT_EQ(expected.DebugString(), root.DebugString()); -} - -TEST_F(TFProfStatsTest, CheckPointOpType) { - Options opts(3, 0, 0, 0, 0, 0, -1, "name", - {kCkptVarType}, // accout_type_regexes - {".*"}, {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops"}, "", {}); - const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); - - TFGraphNodeProto expected; - CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " - "0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: " - "370\nchildren {\n name: \"conv2d\"\n exec_micros: 0\n " - "requested_bytes: 0\n total_exec_micros: 2\n total_requested_bytes: " - "560\n total_parameters: 140\n children {\n name: \"conv2d/bias\"\n " - " exec_micros: 1\n requested_bytes: 20\n parameters: 5\n " - "total_exec_micros: 1\n total_requested_bytes: 20\n " - "total_parameters: 5\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d/kernel\"\n " - "exec_micros: 1\n requested_bytes: 540\n parameters: 135\n " - "total_exec_micros: 1\n total_requested_bytes: 540\n " - "total_parameters: 135\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " - "requested_bytes: 0\n total_exec_micros: 3\n total_requested_bytes: " - "920\n total_parameters: 230\n children {\n name: " - "\"conv2d_1/bias\"\n exec_micros: 1\n requested_bytes: 20\n " - "parameters: 5\n total_exec_micros: 1\n total_requested_bytes: " - "20\n total_parameters: 5\n devices: " - 
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d_1/kernel\"\n " - " exec_micros: 2\n requested_bytes: 900\n parameters: 225\n " - "total_exec_micros: 2\n total_requested_bytes: 900\n " - "total_parameters: 225\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nfloat_ops: 0\ntotal_float_ops: 0\n", - &expected)); - EXPECT_EQ(expected.DebugString(), root.DebugString()); -} - -TEST_F(TFProfStatsTest, TestGraph) { - Options opts(100, 0, 10000, 0, 0, 0, -1, "name", {".*"}, - {"cost.*"}, // start_name_regexes - {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops"}, "", {}); - const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("graph", opts); - - TFGraphNodeProto expected; - CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: 0\n" - "total_exec_micros: 97\ntotal_requested_bytes: " - "8656\ntotal_parameters: 370\nfloat_ops: " - "0\ntotal_float_ops: 34360\n", - &expected)); - EXPECT_EQ(expected.DebugString(), root.DebugString()); -} - -TEST_F(TFProfStatsTest, TestFloatOps) { - Options opts(10, 0, 0, 0, 1, 0, -1, "name", {".*"}, {".*"}, {""}, {".*"}, - {""}, false, {"float_ops"}, "", {}); - const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); - - TFGraphNodeProto expected; - CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " - "0\ntotal_exec_micros: 97\ntotal_requested_bytes: " - "8656\ntotal_parameters: 370\nchildren {\n name: \"conv2d/BiasAdd\"\n " - "exec_micros: 12\n requested_bytes: 1440\n total_exec_micros: 12\n " - "total_requested_bytes: 1440\n total_parameters: 0\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 360\n " - "total_float_ops: 360\n input_shapes {\n key: 0\n value {\n " - "unknown_rank: true\n }\n }\n input_shapes {\n key: 
1\n value " - "{\n unknown_rank: true\n }\n }\n}\nchildren {\n name: " - "\"conv2d/convolution\"\n exec_micros: 60\n requested_bytes: 1440\n " - "total_exec_micros: 60\n total_requested_bytes: 1440\n " - "total_parameters: 0\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 19440\n " - "total_float_ops: 19440\n input_shapes {\n key: 0\n value {\n " - " unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n " - "value {\n unknown_rank: true\n }\n }\n}\nchildren {\n name: " - "\"conv2d_2/BiasAdd\"\n exec_micros: 2\n requested_bytes: 640\n " - "total_exec_micros: 2\n total_requested_bytes: 640\n total_parameters: " - "0\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: " - "160\n total_float_ops: 160\n input_shapes {\n key: 0\n value " - "{\n unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n " - " value {\n unknown_rank: true\n }\n }\n}\nchildren {\n " - "name: \"conv2d_2/convolution\"\n exec_micros: 13\n requested_bytes: " - "640\n total_exec_micros: 13\n total_requested_bytes: 640\n " - "total_parameters: 0\n devices: " - "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 14400\n " - "total_float_ops: 14400\n input_shapes {\n key: 0\n value {\n " - " unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n " - "value {\n unknown_rank: true\n }\n }\n}\nfloat_ops: " - "0\ntotal_float_ops: 34360\n", - &expected)); - EXPECT_EQ(expected.DebugString(), root.DebugString()); -} - -TEST_F(TFProfStatsTest, TestAccountShownNameOnly) { - Options opts(100, 0, 0, 0, 0, 0, -1, "name", {".*"}, {".*"}, {""}, - {"unit_2_1.*DW"}, // show_name_regexes. - {""}, true, // account_displayed_op_only. 
- {"params"}, "", {}); - const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); - - TFGraphNodeProto expected; - CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " - "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: " - "0\nfloat_ops: 0\ntotal_float_ops: 0\n", - &expected)); - EXPECT_EQ(expected.DebugString(), root.DebugString()); -} - -TEST_F(TFProfStatsTest, TestShowTensorValue) { - Options opts(10, 0, 0, 0, 0, 0, -1, "name", {".*"}, {".*"}, {""}, - {"unit_1_0.*gamma"}, {""}, false, - {"tensor_value"}, // Show tensor value from checkpoint. - "", {}); - const TFGraphNodeProto& root = tf_stats_->ShowGraphNode("scope", opts); - TFGraphNodeProto expected; - CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " - "0\ntotal_exec_micros: 97\ntotal_requested_bytes: " - "8656\ntotal_parameters: 370\nfloat_ops: 0\ntotal_float_ops: 34360\n", - &expected)); - EXPECT_EQ(expected.DebugString(), root.DebugString()); -} - -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/tools/tfprof/tfprof_main.cc b/tensorflow/tools/tfprof/tfprof_main.cc deleted file mode 100644 index ae02b526347474e1aa738ee1a84cfabaeb7d723c..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/tfprof_main.cc +++ /dev/null @@ -1,286 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "linenoise.h" -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/checkpoint_reader.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/tools/tfprof/internal/tfprof_options.h" -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" -#include "tensorflow/tools/tfprof/internal/tfprof_utils.h" -#include "tensorflow/tools/tfprof/tfprof_log.pb.h" - -using tensorflow::str_util::Split; - -void completion(const char* buf, linenoiseCompletions* lc) { - tensorflow::string buf_str = buf; - if (buf_str.find(" ") == buf_str.npos) { - for (const char* opt : tensorflow::tfprof::kCmds) { - if (tensorflow::string(opt).find(buf_str) == 0) { - linenoiseAddCompletion(lc, opt); - } - } - return; - } - - tensorflow::string prefix; - int last_dash = buf_str.find_last_of(' '); - if (last_dash != tensorflow::string::npos) { - prefix = buf_str.substr(0, last_dash + 1); - buf_str = buf_str.substr(last_dash + 1, tensorflow::kint32max); - } - for (const char* opt : tensorflow::tfprof::kOptions) { - if (tensorflow::string(opt).find(buf_str) == 0) { - linenoiseAddCompletion(lc, (prefix + opt).c_str()); - } - } -} - -int main(int argc, char** argv) { - tensorflow::string FLAGS_graph_path = ""; - tensorflow::string FLAGS_run_meta_path = ""; - tensorflow::string FLAGS_op_log_path = ""; - tensorflow::string FLAGS_checkpoint_path = ""; - tensorflow::int32 FLAGS_max_depth = 10; - tensorflow::int64 FLAGS_min_bytes = 0; - 
tensorflow::int64 FLAGS_min_micros = 0; - tensorflow::int64 FLAGS_min_params = 0; - tensorflow::int64 FLAGS_min_float_ops = 0; - tensorflow::int64 FLAGS_min_occurrence = 0; - tensorflow::int64 FLAGS_step = -1; - tensorflow::string FLAGS_order_by = "name"; - tensorflow::string FLAGS_account_type_regexes = ".*"; - tensorflow::string FLAGS_start_name_regexes = ".*"; - tensorflow::string FLAGS_trim_name_regexes = ""; - tensorflow::string FLAGS_show_name_regexes = ".*"; - tensorflow::string FLAGS_hide_name_regexes; - bool FLAGS_account_displayed_op_only = false; - tensorflow::string FLAGS_select = "params"; - tensorflow::string FLAGS_output = ""; - for (int i = 0; i < argc; i++) { - fprintf(stderr, "%s\n", argv[i]); - } - - std::vector flag_list = { - tensorflow::Flag("graph_path", &FLAGS_graph_path, - "GraphDef proto text file name"), - tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path, - "Comma-separated list of RunMetadata proto binary " - "files. Each file is given step number 0,1,2,etc"), - tensorflow::Flag("op_log_path", &FLAGS_op_log_path, - "tensorflow::tfprof::OpLog proto binary file name"), - tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path, - "TensorFlow Checkpoint file name"), - tensorflow::Flag("max_depth", &FLAGS_max_depth, "max depth"), - tensorflow::Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"), - tensorflow::Flag("min_micros", &FLAGS_min_micros, "min micros"), - tensorflow::Flag("min_params", &FLAGS_min_params, "min params"), - tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"), - tensorflow::Flag("min_occurrence", &FLAGS_min_occurrence, - "min occurrence"), - tensorflow::Flag("step", &FLAGS_step, - "The stats of which step to use. 
By default average"), - tensorflow::Flag("order_by", &FLAGS_order_by, "order by"), - tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes, - "start name regexes"), - tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes, - "trim name regexes"), - tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes, - "show name regexes"), - tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes, - "hide name regexes"), - tensorflow::Flag("account_displayed_op_only", - &FLAGS_account_displayed_op_only, - "account displayed op only"), - tensorflow::Flag("select", &FLAGS_select, "select"), - tensorflow::Flag("output", &FLAGS_output, "output"), - }; - tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_ok) { - printf("%s", usage.c_str()); - return (2); - } - tensorflow::port::InitMain(argv[0], &argc, &argv); - - fprintf(stderr, "%s\n", FLAGS_graph_path.c_str()); - - std::vector account_type_regexes = - Split(FLAGS_account_type_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector start_name_regexes = - Split(FLAGS_start_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector trim_name_regexes = - Split(FLAGS_trim_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector show_name_regexes = - Split(FLAGS_show_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector hide_name_regexes = - Split(FLAGS_hide_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector select = - Split(FLAGS_select, ',', tensorflow::str_util::SkipEmpty()); - - tensorflow::string output_type; - std::map output_options; - tensorflow::Status s = tensorflow::tfprof::ParseOutput( - FLAGS_output, &output_type, &output_options); - CHECK(s.ok()) << s.ToString(); - - tensorflow::string cmd = ""; - if (argc == 1 && FLAGS_graph_path.empty()) { - printf("1) go/tfprof: Tutorial.\n"); - printf("2) tfprof help: Detail help 
information.\n"); - printf( - "3) tfprof --graph_path : " - "Profiling model structure, tensor shape and # parameters.\n"); - printf( - "4) tfprof --graph_path \\\n" - " --run_meta_path \\\n" - " --op_log_path " - "\\\n" - " --checkpoint_path : " - "Profiling everything!\n"); - return 0; - } else if (argc > 1) { - if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[5]) { - tensorflow::tfprof::PrintHelp(); - return 0; - } - if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[0] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[2] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[3]) { - cmd = argv[1]; - } - } - - printf("Reading Files...\n"); - std::unique_ptr graph(new tensorflow::GraphDef()); - TF_CHECK_OK(tensorflow::tfprof::ReadGraphDef(tensorflow::Env::Default(), - FLAGS_graph_path, graph.get())); - - std::unique_ptr op_log( - new tensorflow::tfprof::OpLog()); - if (!FLAGS_op_log_path.empty()) { - tensorflow::string op_log_str; - s = tensorflow::ReadFileToString(tensorflow::Env::Default(), - FLAGS_op_log_path, &op_log_str); - if (!s.ok()) { - fprintf(stderr, "Failed to read op_log_path: %s\n", s.ToString().c_str()); - return 1; - } - if (!tensorflow::ParseProtoUnlimited(op_log.get(), op_log_str)) { - fprintf(stderr, "Failed to parse op_log_path\n"); - return 1; - } - } - - std::unique_ptr ckpt_reader; - TF_Status* status = TF_NewStatus(); - if (!FLAGS_checkpoint_path.empty()) { - ckpt_reader.reset(new tensorflow::checkpoint::CheckpointReader( - FLAGS_checkpoint_path, status)); - if (TF_GetCode(status) != TF_OK) { - fprintf(stderr, "%s\n", TF_Message(status)); - TF_DeleteStatus(status); - return 1; - } - TF_DeleteStatus(status); - } - - tensorflow::tfprof::TFStats tf_stat( - std::move(graph), nullptr, std::move(op_log), std::move(ckpt_reader)); - - std::vector run_meta_files = - Split(FLAGS_run_meta_path, ',', tensorflow::str_util::SkipEmpty()); - for (int 
i = 0; i < run_meta_files.size(); ++i) { - std::unique_ptr run_meta( - new tensorflow::RunMetadata()); - s = ReadBinaryProto(tensorflow::Env::Default(), run_meta_files[i], - run_meta.get()); - if (!s.ok()) { - fprintf(stderr, "Failed to read run_meta_path %s. Status: %s\n", - run_meta_files[i].c_str(), s.ToString().c_str()); - return 1; - } - tf_stat.ParseRunMeta(i, std::move(run_meta)); - } - - tensorflow::tfprof::Options opts( - FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params, - FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by, - account_type_regexes, start_name_regexes, trim_name_regexes, - show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only, - select, output_type, output_options); - - if (cmd == tensorflow::tfprof::kCmds[2] || - cmd == tensorflow::tfprof::kCmds[3]) { - tf_stat.ShowMultiGraphNode(cmd, opts); - return 0; - } else if (cmd == tensorflow::tfprof::kCmds[0] || - cmd == tensorflow::tfprof::kCmds[1]) { - tf_stat.ShowGraphNode(cmd, opts); - return 0; - } - - linenoiseSetCompletionCallback(completion); - linenoiseHistoryLoad(".tfprof_history.txt"); - - for (char* line = nullptr; (line = linenoise("tfprof> ")) != nullptr;) { - tensorflow::string line_s = line; - free(line); - - if (line_s.empty()) { - printf("%s", opts.ToString().c_str()); - continue; - } - linenoiseHistoryAdd(line_s.c_str()); - linenoiseHistorySave(".tfprof_history.txt"); - - tensorflow::tfprof::Options new_opts = opts; - tensorflow::Status s = - tensorflow::tfprof::ParseCmdLine(line_s, &cmd, &new_opts); - if (!s.ok()) { - fprintf(stderr, "E: %s\n", s.ToString().c_str()); - continue; - } - if (cmd == tensorflow::tfprof::kCmds[4]) { - opts = new_opts; - } else if (cmd == tensorflow::tfprof::kCmds[5]) { - tensorflow::tfprof::PrintHelp(); - } else if (cmd == tensorflow::tfprof::kCmds[2] || - cmd == tensorflow::tfprof::kCmds[3]) { - tf_stat.ShowMultiGraphNode(cmd, new_opts); - } else if (cmd == tensorflow::tfprof::kCmds[0] || - cmd 
== tensorflow::tfprof::kCmds[1]) { - tf_stat.ShowGraphNode(cmd, new_opts); - } - } - return 0; -} diff --git a/tensorflow/tools/tfprof/tfprof_options.proto b/tensorflow/tools/tfprof/tfprof_options.proto deleted file mode 100644 index 27eafb1ca9c27a8f03324bf95b31715014d5d95b..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/tfprof_options.proto +++ /dev/null @@ -1,26 +0,0 @@ -syntax = "proto2"; - -package tensorflow.tfprof; - -// Refers to tfprof_options.h/cc for documentation. -// Only used to pass tfprof options from Python to C++. -message OptionsProto { - optional int64 max_depth = 1; - optional int64 min_bytes = 2; - optional int64 min_micros = 3; - optional int64 min_params = 4; - optional int64 min_float_ops = 5; - optional int64 min_occurrence = 17; - optional int64 step = 18 [default = -1]; - - optional string order_by = 7; - repeated string account_type_regexes = 8; - repeated string start_name_regexes = 9; - repeated string trim_name_regexes = 10; - repeated string show_name_regexes = 11; - repeated string hide_name_regexes = 12; - optional bool account_displayed_op_only = 13; - repeated string select = 14; - optional string output = 15; - optional string dump_to_file = 16; -} diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index de5525b0f208821594e06ca9cf3b029838afbdab..c2f42ba0c5ab2f50d768a1438ec23a1585bcd4d2 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -2,15 +2,13 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") -load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") +load("//third_party/mkl:build_defs.bzl", "mkl_repository") +load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", + "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@io_bazel_rules_closure//closure:defs.bzl", 
"web_library_external") load("//third_party/py:python_configure.bzl", "python_configure") - -load("//third_party:polymer.bzl", "tensorboard_polymer_workspace") -load("//third_party:python.bzl", "tensorboard_python_workspace") -load("//third_party:js.bzl", "tensorboard_js_workspace") -load("//third_party:typings.bzl", "tensorboard_typings_workspace") +load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", + "arm_compiler_configure") def _is_windows(repository_ctx): @@ -88,7 +86,6 @@ temp_workaround_http_archive = repository_rule( }, ) - # Executes specified command with arguments and calls 'fail' if it exited with # non-zero code def _execute_and_check_ret_code(repo_ctx, cmd_and_args): @@ -146,16 +143,29 @@ def tf_workspace(path_prefix="", tf_repo_name=""): cuda_configure(name="local_config_cuda") sycl_configure(name="local_config_sycl") python_configure(name="local_config_python") + + # Point //external/local_config_arm_compiler to //external/arm_compiler + arm_compiler_configure( + name="local_config_arm_compiler", + remote_config_repo="../arm_compiler", + build_file = str(Label("//third_party/toolchains/cpus/arm:BUILD"))) + + mkl_repository( + name = "mkl", + urls = [ + "http://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.7/mklml_lnx_2018.0.20170425.tgz", + "https://github.com/01org/mkl-dnn/releases/download/v0.7/mklml_lnx_2018.0.20170425.tgz", + ], + sha256 = "3cc2501fb209e1fd0960a5f61c919438f9619c68a644dcebf0fdf69b07460c57", + strip_prefix = "mklml_lnx_2018.0.20170425", + build_file = str(Label("//third_party/mkl:mkl.BUILD")), + repository = tf_repo_name, + ) + if path_prefix: print("path_prefix was specified to tf_workspace but is no longer used " + "and will be removed in the future.") - # TODO(dandelion): Take these out when TB exits TF - tensorboard_polymer_workspace() - tensorboard_python_workspace() - tensorboard_typings_workspace() - tensorboard_js_workspace() - native.new_http_archive( name = "eigen_archive", urls = [ 
@@ -167,6 +177,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:eigen.BUILD")), ) + native.new_http_archive( + name = "arm_compiler", + build_file = str(Label("//:arm_compiler.BUILD")), + sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969", + strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf", + urls = [ + "http://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + ], + ) + native.new_http_archive( name = "libxsmm_archive", urls = [ @@ -217,11 +238,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "farmhash_archive", urls = [ - "http://mirror.bazel.build/github.com/google/farmhash/archive/92e897b282426729f4724d91a637596c7e2fe28f.zip", - "https://github.com/google/farmhash/archive/92e897b282426729f4724d91a637596c7e2fe28f.zip", + "http://mirror.bazel.build/github.com/google/farmhash/archive/23eecfbe7e84ebf2e229bd02248f431c36e12f1a.zip", + "https://github.com/google/farmhash/archive/23eecfbe7e84ebf2e229bd02248f431c36e12f1a.zip", ], - sha256 = "4c626d1f306bda2c6804ab955892f803f5245f4dcaecb4979dc08b091256da54", - strip_prefix = "farmhash-92e897b282426729f4724d91a637596c7e2fe28f", + sha256 = "55215f8cd3ddbe9781f6fe5cc228731d6dcc8301b6191c6d420034c3fff1cb8d", + strip_prefix = "farmhash-23eecfbe7e84ebf2e229bd02248f431c36e12f1a", build_file = str(Label("//third_party:farmhash.BUILD")), ) @@ -291,26 +312,59 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "six_archive", urls = [ "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - "http://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", ], sha256 = 
"105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", strip_prefix = "six-1.10.0", build_file = str(Label("//third_party:six.BUILD")), ) + native.new_http_archive( + name = "org_python_pypi_backports_weakref", + urls = [ + "http://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + ], + sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", + strip_prefix = "backports.weakref-1.0rc1/src", + build_file = str(Label("//third_party:backports_weakref.BUILD")), + ) + + native.new_http_archive( + name = "com_github_andreif_codegen", + urls = [ + "http://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", + "https://github.com/andreif/codegen/archive/1.0.tar.gz", + ], + sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", + strip_prefix = "codegen-1.0", + build_file = str(Label("//third_party:codegen.BUILD")), + ) + + filegroup_external( + name = "org_python_license", + licenses = ["notice"], # Python 2.0 + sha256_urls = { + "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ + "http://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", + "https://docs.python.org/2.7/_sources/license.txt", + ], + }, + ) + native.bind( name = "six", actual = "@six_archive//:six", ) patched_http_archive( - name = "protobuf", + name = "protobuf_archive", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", - "https://github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", + "https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + 
"http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], - sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0", - strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a", + sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", + strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", # TODO: remove patching when tensorflow stops linking same protos into # multiple shared libraries loaded in runtime by python. # This patch fixes a runtime crash when tensorflow is compiled @@ -318,27 +372,32 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patch_file = str(Label("//third_party/protobuf:add_noinlines.patch")), ) + native.bind( + name = "protobuf", + actual = "@protobuf_archive//:protobuf", + ) + # We need to import the protobuf library under the names com_google_protobuf # and com_google_protobuf_cc to enable proto_library support in bazel. # Unfortunately there is no way to alias http_archives at the moment. 
native.http_archive( name = "com_google_protobuf", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", - "https://github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", + "https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], - sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0", - strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a", + sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", + strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", ) native.http_archive( name = "com_google_protobuf_cc", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", - "https://github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", + "https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], - sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0", - strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a", + sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", + strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", ) native.new_http_archive( @@ -416,23 +475,30 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # to point to the protobuf's compiler library. 
native.bind( name = "protobuf_clib", - actual = "@protobuf//:protoc_lib", + actual = "@protobuf_archive//:protoc_lib", ) native.bind( - name = "protobuf_compiler", - actual = "@protobuf//:protoc_lib", + name = "libssl", + actual = "@boringssl//:ssl", ) - native.new_http_archive( + # gRPC has includes directly from their third_party path for nanopb, so we + # must depend on their version of it. + native.bind( + name = "nanopb", + actual = "@grpc//third_party/nanopb:nanopb", + ) + + patched_http_archive( name = "grpc", urls = [ - "http://mirror.bazel.build/github.com/grpc/grpc/archive/d7ff4ff40071d2b486a052183e3e9f9382afb745.tar.gz", - "https://github.com/grpc/grpc/archive/d7ff4ff40071d2b486a052183e3e9f9382afb745.tar.gz", + "http://mirror.bazel.build/github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", + "https://github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", ], - sha256 = "a15f352436ab92c521b1ac11e729e155ace38d0856380cf25048c5d1d9ba8e31", - strip_prefix = "grpc-d7ff4ff40071d2b486a052183e3e9f9382afb745", - build_file = str(Label("//third_party:grpc.BUILD")), + sha256 = "2004635e6a078acfac8ffa71738397796be4f8fb72f572cc44ecee5d99511d9f", + strip_prefix = "grpc-781fd6f6ea03645a520cd5c675da67ab61f87e4b", + patch_file = str(Label("//third_party/grpc:grpc.patch")), ) # protobuf expects //external:grpc_cpp_plugin to point to grpc's @@ -463,11 +529,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz", + "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/4d98985c94c36b9eb4396c91fe0a72a0c5f707b2.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/4d98985c94c36b9eb4396c91fe0a72a0c5f707b2.tar.gz", ], - sha256 = 
"72e34e2411a06d4200a2688ee83832805fbef23a12ea481f31c2b8866fde007a", - strip_prefix = "llvm-e156d99231a7735d06a97b5b83de70bf4ce4f034", + sha256 = "1a085c995522fa19900568c03eb595b425df53842c7f281e3ab79aaa04affffa", + strip_prefix = "llvm-4d98985c94c36b9eb4396c91fe0a72a0c5f707b2", build_file = str(Label("//third_party/llvm:llvm.BUILD")), repository = tf_repo_name, ) @@ -499,7 +565,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@jsoncpp_git//:jsoncpp", ) - native.http_archive( + patched_http_archive( name = "boringssl", urls = [ "http://mirror.bazel.build/github.com/google/boringssl/archive/bbcaa15b0647816b9a1a9b9e0d209cd6712f0105.tar.gz", @@ -507,22 +573,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "025264d6e9a7ad371f2f66d17a28b6627de0c9592dc2eb54afd062f68f1f9aa3", strip_prefix = "boringssl-bbcaa15b0647816b9a1a9b9e0d209cd6712f0105", - ) - - native.new_http_archive( - name = "nanopb_git", - urls = [ - "http://mirror.bazel.build/github.com/nanopb/nanopb/archive/1251fa1065afc0d62f635e0f63fec8276e14e13c.tar.gz", - "https://github.com/nanopb/nanopb/archive/1251fa1065afc0d62f635e0f63fec8276e14e13c.tar.gz", - ], - sha256 = "ab1455c8edff855f4f55b68480991559e51c11e7dab060bbab7cffb12dd3af33", - strip_prefix = "nanopb-1251fa1065afc0d62f635e0f63fec8276e14e13c", - build_file = str(Label("//third_party:nanopb.BUILD")), - ) - native.bind( - name = "nanopb", - actual = "@nanopb_git//:nanopb", + # Add patch to boringssl code to support s390x + patch_file = str(Label("//third_party/boringssl:add_boringssl_s390x.patch")), ) native.new_http_archive( @@ -623,3 +676,28 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:pprof.BUILD")), ) + native.new_http_archive( + name = "cub_archive", + urls = [ + "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.6.4.zip", + "https://github.com/NVlabs/cub/archive/1.6.4.zip", + ], + sha256 = "966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee", + 
strip_prefix = "cub-1.6.4", + build_file = str(Label("//third_party:cub.BUILD")), + ) + + native.bind( + name = "cub", + actual = "@cub_archive//:cub", + ) + + native.http_archive( + name = "bazel_toolchains", + urls = [ + "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz", + ], + sha256 = "3903fd93b96b42067e00b7973a2c16c34e761ad7a0b55e1557d408f352849e41", + strip_prefix = "bazel-toolchains-bccee4855c049d34bac481083b4c68e2fab8cc50", + ) diff --git a/third_party/backports_weakref.BUILD b/third_party/backports_weakref.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0adfc5f05419e736b6af01252674e6fb11e6b8d7 --- /dev/null +++ b/third_party/backports_weakref.BUILD @@ -0,0 +1,22 @@ +# Description: +# Backport of new features in Python's weakref module. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Python 2.0 + +py_library( + name = "org_python_pypi_backports_weakref", + srcs = [ + "backports/__init__.py", + "backports/weakref.py", + ], + srcs_version = "PY2AND3", +) + +genrule( + name = "license", + srcs = ["@org_python_license"], + outs = ["LICENSE"], + cmd = "cp $< $@", +) diff --git a/third_party/bleach.BUILD b/third_party/bleach.BUILD deleted file mode 100644 index 1bf75b84a769642d74b9fdef78708eaffceb113e..0000000000000000000000000000000000000000 --- a/third_party/bleach.BUILD +++ /dev/null @@ -1,20 +0,0 @@ -# Description: -# Build file for Bleach. 
-package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_library( - name = "org_mozilla_bleach", - srcs = [ - "bleach/__init__.py", - "bleach/callbacks.py", - "bleach/encoding.py", - "bleach/sanitizer.py", - "bleach/version.py", - ], - srcs_version = "PY2AND3", - deps = ["@org_html5lib"], -) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json b/third_party/boringssl/BUILD similarity index 100% rename from tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json rename to third_party/boringssl/BUILD diff --git a/third_party/boringssl/add_boringssl_s390x.patch b/third_party/boringssl/add_boringssl_s390x.patch new file mode 100644 index 0000000000000000000000000000000000000000..0b41a4aa96831540bb55c69337bac1ed7b7cd651 --- /dev/null +++ b/third_party/boringssl/add_boringssl_s390x.patch @@ -0,0 +1,13 @@ +diff --git a/src/include/openssl/base.h b/src/include/openssl/base.h +index 7a3adfb..88012ad 100644 +--- a/src/include/openssl/base.h ++++ b/src/include/openssl/base.h +@@ -94,6 +94,8 @@ extern "C" { + #elif defined(__pnacl__) + #define OPENSSL_32_BIT + #define OPENSSL_PNACL ++#elif defined(__s390x__) ++#define OPENSSL_64_BIT + #else + #error "Unknown target CPU" + #endif diff --git a/third_party/clutz.BUILD b/third_party/clutz.BUILD deleted file mode 100644 index 593b70366a3a0908b91120ce5351fe7c2c0159b3..0000000000000000000000000000000000000000 --- a/third_party/clutz.BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Description: -# Build tool for making TypeScript .d.ts files from Closure JavaScript. 
- -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # MIT - -exports_files([ - "LICENSE", - "src/resources/closure.lib.d.ts", -]) - -JVM_FLAGS = [ - "-Xss20m", # JSCompiler needs big stacks for recursive parsing - "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive -] - -java_binary( - name = "clutz", - srcs = glob(["src/main/java/com/google/javascript/clutz/**/*.java"]), - jvm_flags = JVM_FLAGS, - main_class = "com.google.javascript.clutz.DeclarationGenerator", - deps = [ - "@args4j", - "@com_google_code_findbugs_jsr305", - "@com_google_code_gson", - "@com_google_guava", - "@com_google_javascript_closure_compiler", - ], -) - -java_binary( - name = "gents", - srcs = glob(["src/main/java/com/google/javascript/gents/**/*.java"]), - jvm_flags = JVM_FLAGS, - main_class = "com.google.javascript.gents.TypeScriptGenerator", - deps = [ - "@args4j", - "@com_google_code_findbugs_jsr305", - "@com_google_code_gson", - "@com_google_guava", - "@com_google_javascript_closure_compiler", - ], -) diff --git a/third_party/clutz.bzl b/third_party/clutz.bzl deleted file mode 100644 index f273c78c794c637f96af52c1c1aa96b31acc5a24..0000000000000000000000000000000000000000 --- a/third_party/clutz.bzl +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Build definitions for TypeScript from Closure JavaScript libraries.""" - -load("@io_bazel_rules_closure//closure/private:defs.bzl", - "JS_FILE_TYPE", - "collect_js", - "unfurl") - -CLUTZ_ATTRIBUTES = { - "_clutz": attr.label( - default=Label("@io_angular_clutz//:clutz"), - executable=True, - cfg="host"), - "_clutz_externs": attr.label( - default=Label("@com_google_javascript_closure_compiler_externs"), - allow_files=True), -} - -def extract_dts_from_closure_libraries(ctx): - """Extracts type definitions from closure dependencies. - - This just generates one big .d.ts file for all transitive Closure sources, - and does not pass it down. That means each rule has to duplicate the effort, - but on the other hand allows transitive dependencies on shared rules without - causing duplicate definition errors. - - Args: - ctx: A Skylark context. - Returns: - The generated Clutz typings file, or None if there were no JS deps. - """ - deps = unfurl(ctx.attr.deps, provider="closure_js_library") - js = collect_js(ctx, deps) - if not js.srcs: - return None - js_typings = ctx.new_file(ctx.bin_dir, "%s-js-typings.d.ts" % ctx.label.name) - srcs = depset(JS_FILE_TYPE.filter(ctx.files._clutz_externs)) + js.srcs - args = ["-o", js_typings.path] - for src in srcs: - args.append(src.path) - if getattr(ctx.attr, "clutz_entry_points", None): - args.append("--closure_entry_points") - args.extend(ctx.attr.clutz_entry_points) - ctx.action( - inputs=list(srcs), - outputs=[js_typings], - executable=ctx.executable._clutz, - arguments=args, - mnemonic="Clutz", - progress_message="Running Clutz on %d JS files %s" % ( - len(srcs), ctx.label)) - return js_typings - -################################################################################ -# The following definitions are for API compatibility with internal clutz.bzl - -CLUTZ_OUTPUTS = {} - -def _clutz_aspect_impl(target, ctx): - return struct() - -clutz_aspect = aspect( - implementation=_clutz_aspect_impl, - attr_aspects=["exports"]) 
diff --git a/third_party/codegen.BUILD b/third_party/codegen.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df436c81635a71421a67fa8d8c84eb8dfcc97d7b --- /dev/null +++ b/third_party/codegen.BUILD @@ -0,0 +1,16 @@ +# -*- mode: python; -*- +# +# Description: +# Extension to ast that allow ast -> python code generation. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # New BSD + +exports_files(["LICENSE"]) + +py_library( + name = "com_github_andreif_codegen", + srcs = glob(["codegen.py"]), + srcs_version = "PY2AND3", +) diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..29159c9dad3d32121ce05278821e41b39f3f2a20 --- /dev/null +++ b/third_party/cub.BUILD @@ -0,0 +1,26 @@ +# Description: CUB library which is a set of primitives for GPU programming. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # BSD + +exports_files(["LICENSE.TXT"]) + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda") + +filegroup( + name = "cub_header_files", + srcs = glob([ + "cub/**", + ]), +) + +cc_library( + name = "cub", + hdrs = if_cuda([":cub_header_files"]), + deps = [ + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/third_party/gif.BUILD b/third_party/gif.BUILD index ad6821af3ccb5b3b15151427c99db4280a6905bf..21c5c11a44dd7bdcb3bbea839c751fc9e6b7e8e0 100644 --- a/third_party/gif.BUILD +++ b/third_party/gif.BUILD @@ -20,6 +20,15 @@ cc_library( "lib/quantize.c", ], hdrs = ["lib/gif_lib.h"], + defines = select({ + #"@%ws%//tensorflow:android": [ + ":android": [ + "S_IREAD=S_IRUSR", + "S_IWRITE=S_IWUSR", + "S_IEXEC=S_IXUSR", + ], + "//conditions:default": [], + }), includes = ["lib/."], visibility = ["//visibility:public"], deps = select({ @@ -54,3 +63,10 @@ config_setting( "cpu": "x64_windows", }, ) + +config_setting( + name = "android", + values = {"crosstool_top": 
"//external:android/crosstool"}, +) + + diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 242439daf456d6fd31a140e5d2c56d3e89900652..2558f46fd55c35b5089cc0119f2654f598e5128a 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -213,7 +213,7 @@ def InvokeNvcc(argv, log=False): ' --compiler-options "' + host_compiler_options + '"' + ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + ' -I .' + - ' -x cu ' + includes + ' ' + srcs + ' -M -o ' + depfile) + ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile) if log: Log(cmd) exit_status = os.system(cmd) if exit_status != 0: diff --git a/third_party/gpus/crosstool/remote.BUILD.tpl b/third_party/gpus/crosstool/remote.BUILD.tpl new file mode 100644 index 0000000000000000000000000000000000000000..b2316331db257a39086bdd5ca02b5ca6848cebcb --- /dev/null +++ b/third_party/gpus/crosstool/remote.BUILD.tpl @@ -0,0 +1,10 @@ +# Description: +# Template for crosstool Build file to use a pre-generated config. 
+licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +alias( + name = "toolchain", + actual = "%{remote_cuda_repo}:toolchain", +) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index f7610dd7a99e3c65ac494d23f0a408d4391680c0..b752734a08a1ac7a60582ebd7e60ec3c1564f353 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -40,20 +40,23 @@ config_setting( cc_library( name = "cuda_headers", hdrs = [ - "cuda_config.h", + "cuda/cuda_config.h", %{cuda_headers} ], includes = [ ".", - "include", + "cuda/include", ], visibility = ["//visibility:public"], ) cc_library( name = "cudart_static", - srcs = ["lib/%{cudart_static_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cudart_static_lib}"], + includes = [ + ".", + "cuda/include", + ], linkopts = select({ ":freebsd": [], "//conditions:default": ["-ldl"], @@ -66,95 +69,120 @@ cc_library( cc_library( name = "cuda_driver", - srcs = ["lib/%{cuda_driver_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cuda_driver_lib}"], + includes = [ + ".", + "cuda/include", + ], visibility = ["//visibility:public"], ) cc_library( name = "cudart", - srcs = ["lib/%{cudart_lib}"], - data = ["lib/%{cudart_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cudart_lib}"], + data = ["cuda/lib/%{cudart_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cublas", - srcs = ["lib/%{cublas_lib}"], - data = ["lib/%{cublas_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cublas_lib}"], + data = ["cuda/lib/%{cublas_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cusolver", - srcs = ["lib/%{cusolver_lib}"], - data = ["lib/%{cusolver_lib}"], - includes = ["include"], - linkstatic = 1, + srcs = ["cuda/lib/%{cusolver_lib}"], + data = ["cuda/lib/%{cusolver_lib}"], + includes = [ + ".", + 
"cuda/include", + ], linkopts = ["-lgomp"], + linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cudnn", - srcs = ["lib/%{cudnn_lib}"], - data = ["lib/%{cudnn_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cudnn_lib}"], + data = ["cuda/lib/%{cudnn_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cufft", - srcs = ["lib/%{cufft_lib}"], - data = ["lib/%{cufft_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cufft_lib}"], + data = ["cuda/lib/%{cufft_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "curand", - srcs = ["lib/%{curand_lib}"], - data = ["lib/%{curand_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{curand_lib}"], + data = ["cuda/lib/%{curand_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cuda", + visibility = ["//visibility:public"], deps = [ + ":cublas", ":cuda_headers", ":cudart", - ":cublas", ":cudnn", ":cufft", ":curand", ], - visibility = ["//visibility:public"], ) cc_library( name = "cupti_headers", hdrs = [ - "cuda_config.h", + "cuda/cuda_config.h", ":cuda-extras", ], includes = [ ".", - "extras/CUPTI/include/", + "cuda/extras/CUPTI/include/", ], visibility = ["//visibility:public"], ) cc_library( name = "cupti_dsos", - data = ["lib/%{cupti_lib}"], + data = ["cuda/lib/%{cupti_lib}"], + includes = [ + ".", + "cuda/include", + ], visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/remote.BUILD.tpl b/third_party/gpus/cuda/remote.BUILD.tpl new file mode 100644 index 0000000000000000000000000000000000000000..d88d512b90c352e6a301ed6efe8266d8dd6bf744 --- /dev/null +++ b/third_party/gpus/cuda/remote.BUILD.tpl @@ -0,0 +1,105 @@ +# Description: +# Template for cuda Build file to use a pre-generated config. 
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "using_nvcc", + values = { + "define": "using_cuda_nvcc=true", + }, +) + +config_setting( + name = "using_clang", + values = { + "define": "using_cuda_clang=true", + }, +) + +# Equivalent to using_clang && -c opt. +config_setting( + name = "using_clang_opt", + values = { + "define": "using_cuda_clang=true", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "darwin", + values = {"cpu": "darwin"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, + visibility = ["//visibility:public"], +) + +alias( + name = "cuda_headers", + actual = "%{remote_cuda_repo}cuda:cuda_headers", +) + +alias( + name = "cudart_static", + actual = "%{remote_cuda_repo}cuda:cudart_static", +) + +alias( + name = "cuda_driver", + actual = "%{remote_cuda_repo}cuda:cuda_driver", +) + +alias( + name = "cudart", + actual = "%{remote_cuda_repo}cuda:cudart", +) + +alias( + name = "cublas", + actual = "%{remote_cuda_repo}cuda:cublas", +) + +alias( + name = "cusolver", + actual = "%{remote_cuda_repo}cuda:cusolver", +) + +alias( + name = "cudnn", + actual = "%{remote_cuda_repo}cuda:cudnn", +) + +alias( + name = "cufft", + actual = "%{remote_cuda_repo}cuda:cufft", +) + +alias( + name = "curand", + actual = "%{remote_cuda_repo}cuda:curand", +) + +alias( + name = "cuda", + actual = "%{remote_cuda_repo}cuda:cuda", +) + +alias( + name = "cupti_headers", + actual = "%{remote_cuda_repo}cuda:cupti_headers", +) + +alias( + name = "cupti_dsos", + actual = "%{remote_cuda_repo}cuda:cupti_dsos", +) + +alias( + name = "libdevice_root", + actual = "%{remote_cuda_repo}cuda:libdevice_root", +) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 61932a8e6d1a699392c4de73ee36ed681d9eda94..4dd3169d418797fbda656d33c53e3f147b38725d 100644 --- 
a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -26,6 +26,7 @@ _TF_CUDA_VERSION = "TF_CUDA_VERSION" _TF_CUDNN_VERSION = "TF_CUDNN_VERSION" _CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH" _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _DEFAULT_CUDA_VERSION = "" _DEFAULT_CUDNN_VERSION = "" @@ -739,19 +740,19 @@ def _create_dummy_repository(repository_ctx): # Create dummy files for the CUDA toolkit since they are still required by # tensorflow/core/platform/default/build_config:cuda. - repository_ctx.file("cuda/include/cuda.h", "") - repository_ctx.file("cuda/include/cublas.h", "") - repository_ctx.file("cuda/include/cudnn.h", "") - repository_ctx.file("cuda/extras/CUPTI/include/cupti.h", "") - repository_ctx.file("cuda/lib/%s" % _lib_name("cuda", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cudart", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cudart_static", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cublas", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cusolver", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cudnn", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("curand", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cufft", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cupti", cpu_value)) + repository_ctx.file("cuda/cuda/include/cuda.h", "") + repository_ctx.file("cuda/cuda/include/cublas.h", "") + repository_ctx.file("cuda/cuda/include/cudnn.h", "") + repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "") + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value)) + 
repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value)) # Set up cuda_config.h, which is used by # tensorflow/stream_executor/dso_loader.cc. @@ -763,7 +764,7 @@ def _create_dummy_repository(repository_ctx): "CudaVersion(\"%s\")" % c for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES]), "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH, - }) + }, "cuda/cuda/cuda_config.h") # If cuda_configure is not configured to build with GPU support, and the user # attempts to build with --config=cuda, add a dummy build rule to intercept @@ -820,6 +821,13 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, dest_files = files.replace(src_dir, '').splitlines() src_files = files.splitlines() command = [] + if not _is_windows(repository_ctx): + # We clear folders that might have been generated previously to avoid + # undesired inclusions. Options precede the operand so BSD/macOS rm works. + command.append('if [ -d "$(@D)/extras" ]; then rm -rf "$(@D)/extras"; fi') + command.append('if [ -d "$(@D)/include" ]; then rm -rf "$(@D)/include"; fi') + command.append('if [ -d "$(@D)/lib" ]; then rm -rf "$(@D)/lib"; fi') + command.append('if [ -d "$(@D)/nvvm" ]; then rm -rf "$(@D)/nvvm"; fi') outs = [] for i in range(len(dest_files)): if dest_files[i] != "": @@ -829,7 +837,7 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, # On Windows, symlink is not supported, so we just copy all the files. 
cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s' command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest)) - outs.append(' "' + dest_dir + dest_files[i] + '",') + outs.append(' "' + dest_dir + dest_files[i] + '",') genrule = _genrule(src_dir, genrule_name, " && ".join(command), "\n".join(outs)) return genrule @@ -846,11 +854,11 @@ def _genrule(src_dir, genrule_name, command, outs): genrule_name + '",\n' + ' outs = [\n' + outs + - ' ],\n' + + '\n ],\n' + ' cmd = """\n' + command + - ' """,\n' + - ')\n\n' + '\n """,\n' + + ')\n' ) @@ -883,15 +891,16 @@ def _use_cuda_clang(repository_ctx): return enable_cuda == "1" return False -def _compute_cuda_extra_copts(repository_ctx, cuda_config): +def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): if _use_cuda_clang(repository_ctx): - capability_flags = ["--cuda-gpu-arch=sm_" + cap.replace(".", "") for cap in cuda_config.compute_capabilities] + capability_flags = ["--cuda-gpu-arch=sm_" + + cap.replace(".", "") for cap in compute_capabilities] else: # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc capability_flags = [] return str(capability_flags) -def _create_cuda_repository(repository_ctx): +def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" cuda_config = _get_cuda_config(repository_ctx) @@ -904,19 +913,19 @@ def _create_cuda_repository(repository_ctx): cuda_toolkit_path = cuda_config.cuda_toolkit_path cuda_include_path = cuda_toolkit_path + "/include" genrules = [_symlink_genrule_for_dir(repository_ctx, - cuda_include_path, "include", "cuda-include")] + cuda_include_path, "cuda/include", "cuda-include")] genrules.append(_symlink_genrule_for_dir(repository_ctx, - cuda_toolkit_path + "/nvvm", "nvvm", "cuda-nvvm")) + cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm")) genrules.append(_symlink_genrule_for_dir(repository_ctx, cuda_toolkit_path + "/extras/CUPTI/include", - "extras/CUPTI/include", 
"cuda-extras")) + "cuda/extras/CUPTI/include", "cuda-extras")) cuda_libs = _find_libs(repository_ctx, cuda_config) cuda_lib_src = [] cuda_lib_dest = [] for lib in cuda_libs.values(): cuda_lib_src.append(lib.path) - cuda_lib_dest.append("lib/" + lib.file_name) + cuda_lib_dest.append("cuda/lib/" + lib.file_name) genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib", cuda_lib_src, cuda_lib_dest)) @@ -925,8 +934,9 @@ def _create_cuda_repository(repository_ctx): included_files = _read_dir(repository_ctx, cuda_include_path).replace( cuda_include_path, '').splitlines() if '/cudnn.h' not in included_files: - genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "include/", - "cudnn-include", [cudnn_header_dir + "/cudnn.h"], ["cudnn.h"])) + genrules.append(_symlink_genrule_for_dir(repository_ctx, None, + "cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"], + ["cudnn.h"])) else: genrules.append( 'filegroup(\n' + @@ -939,7 +949,8 @@ def _create_cuda_repository(repository_ctx): _tpl(repository_ctx, "cuda:build_defs.bzl", { "%{cuda_is_configured}": "True", - "%{cuda_extra_copts}": _compute_cuda_extra_copts(repository_ctx, cuda_config), + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, cuda_config.compute_capabilities), }) _tpl(repository_ctx, "cuda:BUILD", @@ -997,16 +1008,35 @@ def _create_cuda_repository(repository_ctx): ["CudaVersion(\"%s\")" % c for c in cuda_config.compute_capabilities]), "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, - }) + }, "cuda/cuda/cuda_config.h") + +def _create_remote_cuda_repository(repository_ctx, remote_config_repo): + """Creates pointers to a remotely configured repo set up to build with CUDA.""" + _tpl(repository_ctx, "cuda:build_defs.bzl", + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, _compute_capabilities(repository_ctx)), + }) + _tpl(repository_ctx, "cuda:remote.BUILD", + { + "%{remote_cuda_repo}": 
remote_config_repo, + }, "cuda/BUILD") + _tpl(repository_ctx, "crosstool:remote.BUILD", { + "%{remote_cuda_repo}": remote_config_repo, + }, "crosstool/BUILD") def _cuda_autoconf_impl(repository_ctx): """Implementation of the cuda_autoconf repository rule.""" if not _enable_cuda(repository_ctx): _create_dummy_repository(repository_ctx) else: - _create_cuda_repository(repository_ctx) - + if _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: + _create_remote_cuda_repository(repository_ctx, + repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO]) + else: + _create_local_cuda_repository(repository_ctx) cuda_configure = repository_rule( @@ -1019,6 +1049,7 @@ cuda_configure = repository_rule( _TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_COMPUTE_CAPABILITIES, + _TF_CUDA_CONFIG_REPO, ], ) diff --git a/third_party/grpc.BUILD b/third_party/grpc.BUILD deleted file mode 100644 index b79259618f2f06c941b5a8e3427dd0d5a0fe1e40..0000000000000000000000000000000000000000 --- a/third_party/grpc.BUILD +++ /dev/null @@ -1,2478 +0,0 @@ -# NOTE(mrry): This file is an edited version of the following file: -# https://raw.githubusercontent.com/grpc/grpc/d7ff4ff40071d2b486a052183e3e9f9382afb745/BUILD -# ...with small modifications to fix the build rules for :grpc++_unsecure. -# -# TODO(mrry): Upstream these fixes back to the gRPC repository. -# TODO(jart): Fix nanopb's BUILD file. Fix grpc BUILD file. - -# GRPC Bazel BUILD file. -# This currently builds C, C++ and Objective-C code. -# This file has been automatically generated from a template file. -# Please look at the templates directory instead. -# This file can be regenerated from the template by running -# tools/buildgen/generate_projects.sh - -# Copyright 2015, Google Inc. -# All rights reserved. 
-# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -licenses(["notice"]) # 3-clause BSD - -package(default_visibility = ["//visibility:public"]) - -exports_files(["LICENSE"]) - -genrule( - name = "pb_h", - outs = ["third_party/nanopb/pb.h"], - cmd = "echo '#include ' >$@", - visibility = ["//visibility:private"], -) - -genrule( - name = "pb_decode_h", - outs = ["third_party/nanopb/pb_decode.h"], - cmd = "echo '#include ' >$@", - visibility = ["//visibility:private"], -) - -genrule( - name = "pb_encode_h", - outs = ["third_party/nanopb/pb_encode.h"], - cmd = "echo '#include ' >$@", - visibility = ["//visibility:private"], -) - -cc_library( - name = "gpr", - srcs = [ - "src/core/lib/profiling/basic_timers.c", - "src/core/lib/profiling/stap_timers.c", - "src/core/lib/profiling/timers.h", - "src/core/lib/support/alloc.c", - "src/core/lib/support/avl.c", - "src/core/lib/support/backoff.c", - "src/core/lib/support/backoff.h", - "src/core/lib/support/block_annotate.h", - "src/core/lib/support/cmdline.c", - "src/core/lib/support/cpu_iphone.c", - "src/core/lib/support/cpu_linux.c", - "src/core/lib/support/cpu_posix.c", - "src/core/lib/support/cpu_windows.c", - "src/core/lib/support/env.h", - "src/core/lib/support/env_linux.c", - "src/core/lib/support/env_posix.c", - "src/core/lib/support/env_windows.c", - "src/core/lib/support/histogram.c", - "src/core/lib/support/host_port.c", - "src/core/lib/support/log.c", - "src/core/lib/support/log_android.c", - "src/core/lib/support/log_linux.c", - "src/core/lib/support/log_posix.c", - "src/core/lib/support/log_windows.c", - "src/core/lib/support/murmur_hash.c", - "src/core/lib/support/murmur_hash.h", - "src/core/lib/support/slice.c", - "src/core/lib/support/slice_buffer.c", - "src/core/lib/support/stack_lockfree.c", - "src/core/lib/support/stack_lockfree.h", - "src/core/lib/support/string.c", - "src/core/lib/support/string.h", - "src/core/lib/support/string_posix.c", - "src/core/lib/support/string_util_windows.c", - "src/core/lib/support/string_windows.c", - 
"src/core/lib/support/string_windows.h", - "src/core/lib/support/subprocess_posix.c", - "src/core/lib/support/subprocess_windows.c", - "src/core/lib/support/sync.c", - "src/core/lib/support/sync_posix.c", - "src/core/lib/support/sync_windows.c", - "src/core/lib/support/thd.c", - "src/core/lib/support/thd_internal.h", - "src/core/lib/support/thd_posix.c", - "src/core/lib/support/thd_windows.c", - "src/core/lib/support/time.c", - "src/core/lib/support/time_posix.c", - "src/core/lib/support/time_precise.c", - "src/core/lib/support/time_precise.h", - "src/core/lib/support/time_windows.c", - "src/core/lib/support/tls_pthread.c", - "src/core/lib/support/tmpfile.h", - "src/core/lib/support/tmpfile_msys.c", - "src/core/lib/support/tmpfile_posix.c", - "src/core/lib/support/tmpfile_windows.c", - "src/core/lib/support/wrap_memcpy.c", - ], - hdrs = [ - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - "include/grpc/support/alloc.h", - "include/grpc/support/atm.h", - "include/grpc/support/atm_gcc_atomic.h", - "include/grpc/support/atm_gcc_sync.h", - "include/grpc/support/atm_windows.h", - "include/grpc/support/avl.h", - "include/grpc/support/cmdline.h", - "include/grpc/support/cpu.h", - "include/grpc/support/histogram.h", - "include/grpc/support/host_port.h", - "include/grpc/support/log.h", - "include/grpc/support/log_windows.h", - "include/grpc/support/port_platform.h", - "include/grpc/support/slice.h", - 
"include/grpc/support/slice_buffer.h", - "include/grpc/support/string_util.h", - "include/grpc/support/subprocess.h", - "include/grpc/support/sync.h", - "include/grpc/support/sync_generic.h", - "include/grpc/support/sync_posix.h", - "include/grpc/support/sync_windows.h", - "include/grpc/support/thd.h", - "include/grpc/support/time.h", - "include/grpc/support/tls.h", - "include/grpc/support/tls_gcc.h", - "include/grpc/support/tls_msvc.h", - "include/grpc/support/tls_pthread.h", - "include/grpc/support/useful.h", - ], - includes = [ - ".", - "include", - ], - linkopts = ["-lpthread"], -) - -cc_library( - name = "grpc", - srcs = [ - "src/core/ext/census/aggregation.h", - "src/core/ext/census/census_interface.h", - "src/core/ext/census/census_rpc_stats.h", - "src/core/ext/census/context.c", - "src/core/ext/census/gen/census.pb.c", - "src/core/ext/census/gen/census.pb.h", - "src/core/ext/census/grpc_context.c", - "src/core/ext/census/grpc_filter.c", - "src/core/ext/census/grpc_filter.h", - "src/core/ext/census/grpc_plugin.c", - "src/core/ext/census/initialize.c", - "src/core/ext/census/mlog.c", - "src/core/ext/census/mlog.h", - "src/core/ext/census/operation.c", - "src/core/ext/census/placeholders.c", - "src/core/ext/census/rpc_metric_id.h", - "src/core/ext/census/tracing.c", - "src/core/ext/client_config/channel_connectivity.c", - "src/core/ext/client_config/client_channel.c", - "src/core/ext/client_config/client_channel.h", - "src/core/ext/client_config/client_channel_factory.c", - "src/core/ext/client_config/client_channel_factory.h", - "src/core/ext/client_config/client_config.c", - "src/core/ext/client_config/client_config.h", - "src/core/ext/client_config/client_config_plugin.c", - "src/core/ext/client_config/connector.c", - "src/core/ext/client_config/connector.h", - "src/core/ext/client_config/default_initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.h", - 
"src/core/ext/client_config/lb_policy.c", - "src/core/ext/client_config/lb_policy.h", - "src/core/ext/client_config/lb_policy_factory.c", - "src/core/ext/client_config/lb_policy_factory.h", - "src/core/ext/client_config/lb_policy_registry.c", - "src/core/ext/client_config/lb_policy_registry.h", - "src/core/ext/client_config/parse_address.c", - "src/core/ext/client_config/parse_address.h", - "src/core/ext/client_config/resolver.c", - "src/core/ext/client_config/resolver.h", - "src/core/ext/client_config/resolver_factory.c", - "src/core/ext/client_config/resolver_factory.h", - "src/core/ext/client_config/resolver_registry.c", - "src/core/ext/client_config/resolver_registry.h", - "src/core/ext/client_config/subchannel.c", - "src/core/ext/client_config/subchannel.h", - "src/core/ext/client_config/subchannel_call_holder.c", - "src/core/ext/client_config/subchannel_call_holder.h", - "src/core/ext/client_config/subchannel_index.c", - "src/core/ext/client_config/subchannel_index.h", - "src/core/ext/client_config/uri_parser.c", - "src/core/ext/client_config/uri_parser.h", - "src/core/ext/lb_policy/grpclb/load_balancer_api.c", - "src/core/ext/lb_policy/grpclb/load_balancer_api.h", - "src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c", - "src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.h", - "src/core/ext/lb_policy/pick_first/pick_first.c", - "src/core/ext/lb_policy/round_robin/round_robin.c", - "src/core/ext/load_reporting/load_reporting.c", - "src/core/ext/load_reporting/load_reporting.h", - "src/core/ext/load_reporting/load_reporting_filter.c", - "src/core/ext/load_reporting/load_reporting_filter.h", - "src/core/ext/resolver/dns/native/dns_resolver.c", - "src/core/ext/resolver/sockaddr/sockaddr_resolver.c", - "src/core/ext/transport/chttp2/alpn/alpn.c", - "src/core/ext/transport/chttp2/alpn/alpn.h", - "src/core/ext/transport/chttp2/client/insecure/channel_create.c", - "src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c", - 
"src/core/ext/transport/chttp2/client/secure/secure_channel_create.c", - "src/core/ext/transport/chttp2/server/insecure/server_chttp2.c", - "src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c", - "src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c", - "src/core/ext/transport/chttp2/transport/bin_decoder.c", - "src/core/ext/transport/chttp2/transport/bin_decoder.h", - "src/core/ext/transport/chttp2/transport/bin_encoder.c", - "src/core/ext/transport/chttp2/transport/bin_encoder.h", - "src/core/ext/transport/chttp2/transport/chttp2_plugin.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.h", - "src/core/ext/transport/chttp2/transport/frame.h", - "src/core/ext/transport/chttp2/transport/frame_data.c", - "src/core/ext/transport/chttp2/transport/frame_data.h", - "src/core/ext/transport/chttp2/transport/frame_goaway.c", - "src/core/ext/transport/chttp2/transport/frame_goaway.h", - "src/core/ext/transport/chttp2/transport/frame_ping.c", - "src/core/ext/transport/chttp2/transport/frame_ping.h", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.c", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.h", - "src/core/ext/transport/chttp2/transport/frame_settings.c", - "src/core/ext/transport/chttp2/transport/frame_settings.h", - "src/core/ext/transport/chttp2/transport/frame_window_update.c", - "src/core/ext/transport/chttp2/transport/frame_window_update.h", - "src/core/ext/transport/chttp2/transport/hpack_encoder.c", - "src/core/ext/transport/chttp2/transport/hpack_encoder.h", - "src/core/ext/transport/chttp2/transport/hpack_parser.c", - "src/core/ext/transport/chttp2/transport/hpack_parser.h", - "src/core/ext/transport/chttp2/transport/hpack_table.c", - "src/core/ext/transport/chttp2/transport/hpack_table.h", - "src/core/ext/transport/chttp2/transport/http2_errors.h", - "src/core/ext/transport/chttp2/transport/huffsyms.c", - 
"src/core/ext/transport/chttp2/transport/huffsyms.h", - "src/core/ext/transport/chttp2/transport/incoming_metadata.c", - "src/core/ext/transport/chttp2/transport/incoming_metadata.h", - "src/core/ext/transport/chttp2/transport/internal.h", - "src/core/ext/transport/chttp2/transport/parsing.c", - "src/core/ext/transport/chttp2/transport/status_conversion.c", - "src/core/ext/transport/chttp2/transport/status_conversion.h", - "src/core/ext/transport/chttp2/transport/stream_lists.c", - "src/core/ext/transport/chttp2/transport/stream_map.c", - "src/core/ext/transport/chttp2/transport/stream_map.h", - "src/core/ext/transport/chttp2/transport/timeout_encoding.c", - "src/core/ext/transport/chttp2/transport/timeout_encoding.h", - "src/core/ext/transport/chttp2/transport/varint.c", - "src/core/ext/transport/chttp2/transport/varint.h", - "src/core/ext/transport/chttp2/transport/writing.c", - "src/core/lib/channel/channel_args.c", - "src/core/lib/channel/channel_args.h", - "src/core/lib/channel/channel_stack.c", - "src/core/lib/channel/channel_stack.h", - "src/core/lib/channel/channel_stack_builder.c", - "src/core/lib/channel/channel_stack_builder.h", - "src/core/lib/channel/compress_filter.c", - "src/core/lib/channel/compress_filter.h", - "src/core/lib/channel/connected_channel.c", - "src/core/lib/channel/connected_channel.h", - "src/core/lib/channel/context.h", - "src/core/lib/channel/http_client_filter.c", - "src/core/lib/channel/http_client_filter.h", - "src/core/lib/channel/http_server_filter.c", - "src/core/lib/channel/http_server_filter.h", - "src/core/lib/compression/algorithm_metadata.h", - "src/core/lib/compression/compression.c", - "src/core/lib/compression/message_compress.c", - "src/core/lib/compression/message_compress.h", - "src/core/lib/debug/trace.c", - "src/core/lib/debug/trace.h", - "src/core/lib/http/format_request.c", - "src/core/lib/http/format_request.h", - "src/core/lib/http/httpcli.c", - "src/core/lib/http/httpcli.h", - 
"src/core/lib/http/httpcli_security_connector.c", - "src/core/lib/http/parser.c", - "src/core/lib/http/parser.h", - "src/core/lib/iomgr/closure.c", - "src/core/lib/iomgr/closure.h", - "src/core/lib/iomgr/endpoint.c", - "src/core/lib/iomgr/endpoint.h", - "src/core/lib/iomgr/endpoint_pair.h", - "src/core/lib/iomgr/endpoint_pair_posix.c", - "src/core/lib/iomgr/endpoint_pair_windows.c", - "src/core/lib/iomgr/error.c", - "src/core/lib/iomgr/error.h", - "src/core/lib/iomgr/ev_epoll_linux.c", - "src/core/lib/iomgr/ev_epoll_linux.h", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.c", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.h", - "src/core/lib/iomgr/ev_poll_posix.c", - "src/core/lib/iomgr/ev_poll_posix.h", - "src/core/lib/iomgr/ev_posix.c", - "src/core/lib/iomgr/ev_posix.h", - "src/core/lib/iomgr/exec_ctx.c", - "src/core/lib/iomgr/exec_ctx.h", - "src/core/lib/iomgr/executor.c", - "src/core/lib/iomgr/executor.h", - "src/core/lib/iomgr/iocp_windows.c", - "src/core/lib/iomgr/iocp_windows.h", - "src/core/lib/iomgr/iomgr.c", - "src/core/lib/iomgr/iomgr.h", - "src/core/lib/iomgr/iomgr_internal.h", - "src/core/lib/iomgr/iomgr_posix.c", - "src/core/lib/iomgr/iomgr_posix.h", - "src/core/lib/iomgr/iomgr_windows.c", - "src/core/lib/iomgr/load_file.c", - "src/core/lib/iomgr/load_file.h", - "src/core/lib/iomgr/network_status_tracker.c", - "src/core/lib/iomgr/network_status_tracker.h", - "src/core/lib/iomgr/polling_entity.c", - "src/core/lib/iomgr/polling_entity.h", - "src/core/lib/iomgr/pollset.h", - "src/core/lib/iomgr/pollset_set.h", - "src/core/lib/iomgr/pollset_set_windows.c", - "src/core/lib/iomgr/pollset_set_windows.h", - "src/core/lib/iomgr/pollset_windows.c", - "src/core/lib/iomgr/pollset_windows.h", - "src/core/lib/iomgr/resolve_address.h", - "src/core/lib/iomgr/resolve_address_posix.c", - "src/core/lib/iomgr/resolve_address_windows.c", - "src/core/lib/iomgr/sockaddr.h", - "src/core/lib/iomgr/sockaddr_posix.h", - "src/core/lib/iomgr/sockaddr_utils.c", - 
"src/core/lib/iomgr/sockaddr_utils.h", - "src/core/lib/iomgr/sockaddr_windows.h", - "src/core/lib/iomgr/socket_utils_common_posix.c", - "src/core/lib/iomgr/socket_utils_linux.c", - "src/core/lib/iomgr/socket_utils_posix.c", - "src/core/lib/iomgr/socket_utils_posix.h", - "src/core/lib/iomgr/socket_windows.c", - "src/core/lib/iomgr/socket_windows.h", - "src/core/lib/iomgr/tcp_client.h", - "src/core/lib/iomgr/tcp_client_posix.c", - "src/core/lib/iomgr/tcp_client_windows.c", - "src/core/lib/iomgr/tcp_posix.c", - "src/core/lib/iomgr/tcp_posix.h", - "src/core/lib/iomgr/tcp_server.h", - "src/core/lib/iomgr/tcp_server_posix.c", - "src/core/lib/iomgr/tcp_server_windows.c", - "src/core/lib/iomgr/tcp_windows.c", - "src/core/lib/iomgr/tcp_windows.h", - "src/core/lib/iomgr/time_averaged_stats.c", - "src/core/lib/iomgr/time_averaged_stats.h", - "src/core/lib/iomgr/timer.c", - "src/core/lib/iomgr/timer.h", - "src/core/lib/iomgr/timer_heap.c", - "src/core/lib/iomgr/timer_heap.h", - "src/core/lib/iomgr/udp_server.c", - "src/core/lib/iomgr/udp_server.h", - "src/core/lib/iomgr/unix_sockets_posix.c", - "src/core/lib/iomgr/unix_sockets_posix.h", - "src/core/lib/iomgr/unix_sockets_posix_noop.c", - "src/core/lib/iomgr/wakeup_fd_eventfd.c", - "src/core/lib/iomgr/wakeup_fd_nospecial.c", - "src/core/lib/iomgr/wakeup_fd_pipe.c", - "src/core/lib/iomgr/wakeup_fd_pipe.h", - "src/core/lib/iomgr/wakeup_fd_posix.c", - "src/core/lib/iomgr/wakeup_fd_posix.h", - "src/core/lib/iomgr/workqueue.h", - "src/core/lib/iomgr/workqueue_posix.c", - "src/core/lib/iomgr/workqueue_posix.h", - "src/core/lib/iomgr/workqueue_windows.c", - "src/core/lib/iomgr/workqueue_windows.h", - "src/core/lib/json/json.c", - "src/core/lib/json/json.h", - "src/core/lib/json/json_common.h", - "src/core/lib/json/json_reader.c", - "src/core/lib/json/json_reader.h", - "src/core/lib/json/json_string.c", - "src/core/lib/json/json_writer.c", - "src/core/lib/json/json_writer.h", - "src/core/lib/security/context/security_context.c", - 
"src/core/lib/security/context/security_context.h", - "src/core/lib/security/credentials/composite/composite_credentials.c", - "src/core/lib/security/credentials/composite/composite_credentials.h", - "src/core/lib/security/credentials/credentials.c", - "src/core/lib/security/credentials/credentials.h", - "src/core/lib/security/credentials/credentials_metadata.c", - "src/core/lib/security/credentials/fake/fake_credentials.c", - "src/core/lib/security/credentials/fake/fake_credentials.h", - "src/core/lib/security/credentials/google_default/credentials_posix.c", - "src/core/lib/security/credentials/google_default/credentials_windows.c", - "src/core/lib/security/credentials/google_default/google_default_credentials.c", - "src/core/lib/security/credentials/google_default/google_default_credentials.h", - "src/core/lib/security/credentials/iam/iam_credentials.c", - "src/core/lib/security/credentials/iam/iam_credentials.h", - "src/core/lib/security/credentials/jwt/json_token.c", - "src/core/lib/security/credentials/jwt/json_token.h", - "src/core/lib/security/credentials/jwt/jwt_credentials.c", - "src/core/lib/security/credentials/jwt/jwt_credentials.h", - "src/core/lib/security/credentials/jwt/jwt_verifier.c", - "src/core/lib/security/credentials/jwt/jwt_verifier.h", - "src/core/lib/security/credentials/oauth2/oauth2_credentials.c", - "src/core/lib/security/credentials/oauth2/oauth2_credentials.h", - "src/core/lib/security/credentials/plugin/plugin_credentials.c", - "src/core/lib/security/credentials/plugin/plugin_credentials.h", - "src/core/lib/security/credentials/ssl/ssl_credentials.c", - "src/core/lib/security/credentials/ssl/ssl_credentials.h", - "src/core/lib/security/transport/auth_filters.h", - "src/core/lib/security/transport/client_auth_filter.c", - "src/core/lib/security/transport/handshake.c", - "src/core/lib/security/transport/handshake.h", - "src/core/lib/security/transport/secure_endpoint.c", - "src/core/lib/security/transport/secure_endpoint.h", - 
"src/core/lib/security/transport/security_connector.c", - "src/core/lib/security/transport/security_connector.h", - "src/core/lib/security/transport/server_auth_filter.c", - "src/core/lib/security/transport/tsi_error.c", - "src/core/lib/security/transport/tsi_error.h", - "src/core/lib/security/util/b64.c", - "src/core/lib/security/util/b64.h", - "src/core/lib/security/util/json_util.c", - "src/core/lib/security/util/json_util.h", - "src/core/lib/surface/alarm.c", - "src/core/lib/surface/api_trace.c", - "src/core/lib/surface/api_trace.h", - "src/core/lib/surface/byte_buffer.c", - "src/core/lib/surface/byte_buffer_reader.c", - "src/core/lib/surface/call.c", - "src/core/lib/surface/call.h", - "src/core/lib/surface/call_details.c", - "src/core/lib/surface/call_log_batch.c", - "src/core/lib/surface/call_test_only.h", - "src/core/lib/surface/channel.c", - "src/core/lib/surface/channel.h", - "src/core/lib/surface/channel_init.c", - "src/core/lib/surface/channel_init.h", - "src/core/lib/surface/channel_ping.c", - "src/core/lib/surface/channel_stack_type.c", - "src/core/lib/surface/channel_stack_type.h", - "src/core/lib/surface/completion_queue.c", - "src/core/lib/surface/completion_queue.h", - "src/core/lib/surface/event_string.c", - "src/core/lib/surface/event_string.h", - "src/core/lib/surface/init.c", - "src/core/lib/surface/init.h", - "src/core/lib/surface/init_secure.c", - "src/core/lib/surface/lame_client.c", - "src/core/lib/surface/lame_client.h", - "src/core/lib/surface/metadata_array.c", - "src/core/lib/surface/server.c", - "src/core/lib/surface/server.h", - "src/core/lib/surface/validate_metadata.c", - "src/core/lib/surface/version.c", - "src/core/lib/transport/byte_stream.c", - "src/core/lib/transport/byte_stream.h", - "src/core/lib/transport/connectivity_state.c", - "src/core/lib/transport/connectivity_state.h", - "src/core/lib/transport/metadata.c", - "src/core/lib/transport/metadata.h", - "src/core/lib/transport/metadata_batch.c", - 
"src/core/lib/transport/metadata_batch.h", - "src/core/lib/transport/static_metadata.c", - "src/core/lib/transport/static_metadata.h", - "src/core/lib/transport/transport.c", - "src/core/lib/transport/transport.h", - "src/core/lib/transport/transport_impl.h", - "src/core/lib/transport/transport_op_string.c", - "src/core/lib/tsi/fake_transport_security.c", - "src/core/lib/tsi/fake_transport_security.h", - "src/core/lib/tsi/ssl_transport_security.c", - "src/core/lib/tsi/ssl_transport_security.h", - "src/core/lib/tsi/ssl_types.h", - "src/core/lib/tsi/transport_security.c", - "src/core/lib/tsi/transport_security.h", - "src/core/lib/tsi/transport_security_interface.h", - "src/core/plugin_registry/grpc_plugin_registry.c", - "third_party/nanopb/pb.h", - "third_party/nanopb/pb_decode.h", - "third_party/nanopb/pb_encode.h", - ], - hdrs = [ - "include/grpc/byte_buffer.h", - "include/grpc/byte_buffer_reader.h", - "include/grpc/census.h", - "include/grpc/compression.h", - "include/grpc/grpc.h", - "include/grpc/grpc_posix.h", - "include/grpc/grpc_security.h", - "include/grpc/grpc_security_constants.h", - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/byte_buffer.h", - "include/grpc/impl/codegen/byte_buffer_reader.h", - "include/grpc/impl/codegen/compression_types.h", - "include/grpc/impl/codegen/connectivity_state.h", - "include/grpc/impl/codegen/grpc_types.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/propagation_bits.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/status.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", 
- "include/grpc/impl/codegen/time.h", - "include/grpc/status.h", - ], - copts = [ - "-std=gnu99", - ], - includes = [ - ".", - "include", - ], - deps = [ - ":gpr", - "//external:libssl", - "//external:nanopb", - "//external:zlib", - ], -) - -cc_library( - name = "grpc_cronet", - srcs = [ - "src/core/ext/client_config/channel_connectivity.c", - "src/core/ext/client_config/client_channel.c", - "src/core/ext/client_config/client_channel.h", - "src/core/ext/client_config/client_channel_factory.c", - "src/core/ext/client_config/client_channel_factory.h", - "src/core/ext/client_config/client_config.c", - "src/core/ext/client_config/client_config.h", - "src/core/ext/client_config/client_config_plugin.c", - "src/core/ext/client_config/connector.c", - "src/core/ext/client_config/connector.h", - "src/core/ext/client_config/default_initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.h", - "src/core/ext/client_config/lb_policy.c", - "src/core/ext/client_config/lb_policy.h", - "src/core/ext/client_config/lb_policy_factory.c", - "src/core/ext/client_config/lb_policy_factory.h", - "src/core/ext/client_config/lb_policy_registry.c", - "src/core/ext/client_config/lb_policy_registry.h", - "src/core/ext/client_config/parse_address.c", - "src/core/ext/client_config/parse_address.h", - "src/core/ext/client_config/resolver.c", - "src/core/ext/client_config/resolver.h", - "src/core/ext/client_config/resolver_factory.c", - "src/core/ext/client_config/resolver_factory.h", - "src/core/ext/client_config/resolver_registry.c", - "src/core/ext/client_config/resolver_registry.h", - "src/core/ext/client_config/subchannel.c", - "src/core/ext/client_config/subchannel.h", - "src/core/ext/client_config/subchannel_call_holder.c", - "src/core/ext/client_config/subchannel_call_holder.h", - "src/core/ext/client_config/subchannel_index.c", - "src/core/ext/client_config/subchannel_index.h", - 
"src/core/ext/client_config/uri_parser.c", - "src/core/ext/client_config/uri_parser.h", - "src/core/ext/transport/chttp2/alpn/alpn.c", - "src/core/ext/transport/chttp2/alpn/alpn.h", - "src/core/ext/transport/chttp2/client/secure/secure_channel_create.c", - "src/core/ext/transport/chttp2/transport/bin_decoder.c", - "src/core/ext/transport/chttp2/transport/bin_decoder.h", - "src/core/ext/transport/chttp2/transport/bin_encoder.c", - "src/core/ext/transport/chttp2/transport/bin_encoder.h", - "src/core/ext/transport/chttp2/transport/chttp2_plugin.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.h", - "src/core/ext/transport/chttp2/transport/frame.h", - "src/core/ext/transport/chttp2/transport/frame_data.c", - "src/core/ext/transport/chttp2/transport/frame_data.h", - "src/core/ext/transport/chttp2/transport/frame_goaway.c", - "src/core/ext/transport/chttp2/transport/frame_goaway.h", - "src/core/ext/transport/chttp2/transport/frame_ping.c", - "src/core/ext/transport/chttp2/transport/frame_ping.h", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.c", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.h", - "src/core/ext/transport/chttp2/transport/frame_settings.c", - "src/core/ext/transport/chttp2/transport/frame_settings.h", - "src/core/ext/transport/chttp2/transport/frame_window_update.c", - "src/core/ext/transport/chttp2/transport/frame_window_update.h", - "src/core/ext/transport/chttp2/transport/hpack_encoder.c", - "src/core/ext/transport/chttp2/transport/hpack_encoder.h", - "src/core/ext/transport/chttp2/transport/hpack_parser.c", - "src/core/ext/transport/chttp2/transport/hpack_parser.h", - "src/core/ext/transport/chttp2/transport/hpack_table.c", - "src/core/ext/transport/chttp2/transport/hpack_table.h", - "src/core/ext/transport/chttp2/transport/http2_errors.h", - "src/core/ext/transport/chttp2/transport/huffsyms.c", - 
"src/core/ext/transport/chttp2/transport/huffsyms.h", - "src/core/ext/transport/chttp2/transport/incoming_metadata.c", - "src/core/ext/transport/chttp2/transport/incoming_metadata.h", - "src/core/ext/transport/chttp2/transport/internal.h", - "src/core/ext/transport/chttp2/transport/parsing.c", - "src/core/ext/transport/chttp2/transport/status_conversion.c", - "src/core/ext/transport/chttp2/transport/status_conversion.h", - "src/core/ext/transport/chttp2/transport/stream_lists.c", - "src/core/ext/transport/chttp2/transport/stream_map.c", - "src/core/ext/transport/chttp2/transport/stream_map.h", - "src/core/ext/transport/chttp2/transport/timeout_encoding.c", - "src/core/ext/transport/chttp2/transport/timeout_encoding.h", - "src/core/ext/transport/chttp2/transport/varint.c", - "src/core/ext/transport/chttp2/transport/varint.h", - "src/core/ext/transport/chttp2/transport/writing.c", - "src/core/ext/transport/cronet/client/secure/cronet_channel_create.c", - "src/core/ext/transport/cronet/transport/cronet_api_dummy.c", - "src/core/ext/transport/cronet/transport/cronet_transport.c", - "src/core/lib/channel/channel_args.c", - "src/core/lib/channel/channel_args.h", - "src/core/lib/channel/channel_stack.c", - "src/core/lib/channel/channel_stack.h", - "src/core/lib/channel/channel_stack_builder.c", - "src/core/lib/channel/channel_stack_builder.h", - "src/core/lib/channel/compress_filter.c", - "src/core/lib/channel/compress_filter.h", - "src/core/lib/channel/connected_channel.c", - "src/core/lib/channel/connected_channel.h", - "src/core/lib/channel/context.h", - "src/core/lib/channel/http_client_filter.c", - "src/core/lib/channel/http_client_filter.h", - "src/core/lib/channel/http_server_filter.c", - "src/core/lib/channel/http_server_filter.h", - "src/core/lib/compression/algorithm_metadata.h", - "src/core/lib/compression/compression.c", - "src/core/lib/compression/message_compress.c", - "src/core/lib/compression/message_compress.h", - "src/core/lib/debug/trace.c", - 
"src/core/lib/debug/trace.h", - "src/core/lib/http/format_request.c", - "src/core/lib/http/format_request.h", - "src/core/lib/http/httpcli.c", - "src/core/lib/http/httpcli.h", - "src/core/lib/http/httpcli_security_connector.c", - "src/core/lib/http/parser.c", - "src/core/lib/http/parser.h", - "src/core/lib/iomgr/closure.c", - "src/core/lib/iomgr/closure.h", - "src/core/lib/iomgr/endpoint.c", - "src/core/lib/iomgr/endpoint.h", - "src/core/lib/iomgr/endpoint_pair.h", - "src/core/lib/iomgr/endpoint_pair_posix.c", - "src/core/lib/iomgr/endpoint_pair_windows.c", - "src/core/lib/iomgr/error.c", - "src/core/lib/iomgr/error.h", - "src/core/lib/iomgr/ev_epoll_linux.c", - "src/core/lib/iomgr/ev_epoll_linux.h", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.c", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.h", - "src/core/lib/iomgr/ev_poll_posix.c", - "src/core/lib/iomgr/ev_poll_posix.h", - "src/core/lib/iomgr/ev_posix.c", - "src/core/lib/iomgr/ev_posix.h", - "src/core/lib/iomgr/exec_ctx.c", - "src/core/lib/iomgr/exec_ctx.h", - "src/core/lib/iomgr/executor.c", - "src/core/lib/iomgr/executor.h", - "src/core/lib/iomgr/iocp_windows.c", - "src/core/lib/iomgr/iocp_windows.h", - "src/core/lib/iomgr/iomgr.c", - "src/core/lib/iomgr/iomgr.h", - "src/core/lib/iomgr/iomgr_internal.h", - "src/core/lib/iomgr/iomgr_posix.c", - "src/core/lib/iomgr/iomgr_posix.h", - "src/core/lib/iomgr/iomgr_windows.c", - "src/core/lib/iomgr/load_file.c", - "src/core/lib/iomgr/load_file.h", - "src/core/lib/iomgr/network_status_tracker.c", - "src/core/lib/iomgr/network_status_tracker.h", - "src/core/lib/iomgr/polling_entity.c", - "src/core/lib/iomgr/polling_entity.h", - "src/core/lib/iomgr/pollset.h", - "src/core/lib/iomgr/pollset_set.h", - "src/core/lib/iomgr/pollset_set_windows.c", - "src/core/lib/iomgr/pollset_set_windows.h", - "src/core/lib/iomgr/pollset_windows.c", - "src/core/lib/iomgr/pollset_windows.h", - "src/core/lib/iomgr/resolve_address.h", - "src/core/lib/iomgr/resolve_address_posix.c", - 
"src/core/lib/iomgr/resolve_address_windows.c", - "src/core/lib/iomgr/sockaddr.h", - "src/core/lib/iomgr/sockaddr_posix.h", - "src/core/lib/iomgr/sockaddr_utils.c", - "src/core/lib/iomgr/sockaddr_utils.h", - "src/core/lib/iomgr/sockaddr_windows.h", - "src/core/lib/iomgr/socket_utils_common_posix.c", - "src/core/lib/iomgr/socket_utils_linux.c", - "src/core/lib/iomgr/socket_utils_posix.c", - "src/core/lib/iomgr/socket_utils_posix.h", - "src/core/lib/iomgr/socket_windows.c", - "src/core/lib/iomgr/socket_windows.h", - "src/core/lib/iomgr/tcp_client.h", - "src/core/lib/iomgr/tcp_client_posix.c", - "src/core/lib/iomgr/tcp_client_windows.c", - "src/core/lib/iomgr/tcp_posix.c", - "src/core/lib/iomgr/tcp_posix.h", - "src/core/lib/iomgr/tcp_server.h", - "src/core/lib/iomgr/tcp_server_posix.c", - "src/core/lib/iomgr/tcp_server_windows.c", - "src/core/lib/iomgr/tcp_windows.c", - "src/core/lib/iomgr/tcp_windows.h", - "src/core/lib/iomgr/time_averaged_stats.c", - "src/core/lib/iomgr/time_averaged_stats.h", - "src/core/lib/iomgr/timer.c", - "src/core/lib/iomgr/timer.h", - "src/core/lib/iomgr/timer_heap.c", - "src/core/lib/iomgr/timer_heap.h", - "src/core/lib/iomgr/udp_server.c", - "src/core/lib/iomgr/udp_server.h", - "src/core/lib/iomgr/unix_sockets_posix.c", - "src/core/lib/iomgr/unix_sockets_posix.h", - "src/core/lib/iomgr/unix_sockets_posix_noop.c", - "src/core/lib/iomgr/wakeup_fd_eventfd.c", - "src/core/lib/iomgr/wakeup_fd_nospecial.c", - "src/core/lib/iomgr/wakeup_fd_pipe.c", - "src/core/lib/iomgr/wakeup_fd_pipe.h", - "src/core/lib/iomgr/wakeup_fd_posix.c", - "src/core/lib/iomgr/wakeup_fd_posix.h", - "src/core/lib/iomgr/workqueue.h", - "src/core/lib/iomgr/workqueue_posix.c", - "src/core/lib/iomgr/workqueue_posix.h", - "src/core/lib/iomgr/workqueue_windows.c", - "src/core/lib/iomgr/workqueue_windows.h", - "src/core/lib/json/json.c", - "src/core/lib/json/json.h", - "src/core/lib/json/json_common.h", - "src/core/lib/json/json_reader.c", - "src/core/lib/json/json_reader.h", - 
"src/core/lib/json/json_string.c", - "src/core/lib/json/json_writer.c", - "src/core/lib/json/json_writer.h", - "src/core/lib/security/context/security_context.c", - "src/core/lib/security/context/security_context.h", - "src/core/lib/security/credentials/composite/composite_credentials.c", - "src/core/lib/security/credentials/composite/composite_credentials.h", - "src/core/lib/security/credentials/credentials.c", - "src/core/lib/security/credentials/credentials.h", - "src/core/lib/security/credentials/credentials_metadata.c", - "src/core/lib/security/credentials/fake/fake_credentials.c", - "src/core/lib/security/credentials/fake/fake_credentials.h", - "src/core/lib/security/credentials/google_default/credentials_posix.c", - "src/core/lib/security/credentials/google_default/credentials_windows.c", - "src/core/lib/security/credentials/google_default/google_default_credentials.c", - "src/core/lib/security/credentials/google_default/google_default_credentials.h", - "src/core/lib/security/credentials/iam/iam_credentials.c", - "src/core/lib/security/credentials/iam/iam_credentials.h", - "src/core/lib/security/credentials/jwt/json_token.c", - "src/core/lib/security/credentials/jwt/json_token.h", - "src/core/lib/security/credentials/jwt/jwt_credentials.c", - "src/core/lib/security/credentials/jwt/jwt_credentials.h", - "src/core/lib/security/credentials/jwt/jwt_verifier.c", - "src/core/lib/security/credentials/jwt/jwt_verifier.h", - "src/core/lib/security/credentials/oauth2/oauth2_credentials.c", - "src/core/lib/security/credentials/oauth2/oauth2_credentials.h", - "src/core/lib/security/credentials/plugin/plugin_credentials.c", - "src/core/lib/security/credentials/plugin/plugin_credentials.h", - "src/core/lib/security/credentials/ssl/ssl_credentials.c", - "src/core/lib/security/credentials/ssl/ssl_credentials.h", - "src/core/lib/security/transport/auth_filters.h", - "src/core/lib/security/transport/client_auth_filter.c", - "src/core/lib/security/transport/handshake.c", - 
"src/core/lib/security/transport/handshake.h", - "src/core/lib/security/transport/secure_endpoint.c", - "src/core/lib/security/transport/secure_endpoint.h", - "src/core/lib/security/transport/security_connector.c", - "src/core/lib/security/transport/security_connector.h", - "src/core/lib/security/transport/server_auth_filter.c", - "src/core/lib/security/transport/tsi_error.c", - "src/core/lib/security/transport/tsi_error.h", - "src/core/lib/security/util/b64.c", - "src/core/lib/security/util/b64.h", - "src/core/lib/security/util/json_util.c", - "src/core/lib/security/util/json_util.h", - "src/core/lib/surface/alarm.c", - "src/core/lib/surface/api_trace.c", - "src/core/lib/surface/api_trace.h", - "src/core/lib/surface/byte_buffer.c", - "src/core/lib/surface/byte_buffer_reader.c", - "src/core/lib/surface/call.c", - "src/core/lib/surface/call.h", - "src/core/lib/surface/call_details.c", - "src/core/lib/surface/call_log_batch.c", - "src/core/lib/surface/call_test_only.h", - "src/core/lib/surface/channel.c", - "src/core/lib/surface/channel.h", - "src/core/lib/surface/channel_init.c", - "src/core/lib/surface/channel_init.h", - "src/core/lib/surface/channel_ping.c", - "src/core/lib/surface/channel_stack_type.c", - "src/core/lib/surface/channel_stack_type.h", - "src/core/lib/surface/completion_queue.c", - "src/core/lib/surface/completion_queue.h", - "src/core/lib/surface/event_string.c", - "src/core/lib/surface/event_string.h", - "src/core/lib/surface/init.c", - "src/core/lib/surface/init.h", - "src/core/lib/surface/init_secure.c", - "src/core/lib/surface/lame_client.c", - "src/core/lib/surface/lame_client.h", - "src/core/lib/surface/metadata_array.c", - "src/core/lib/surface/server.c", - "src/core/lib/surface/server.h", - "src/core/lib/surface/validate_metadata.c", - "src/core/lib/surface/version.c", - "src/core/lib/transport/byte_stream.c", - "src/core/lib/transport/byte_stream.h", - "src/core/lib/transport/connectivity_state.c", - 
"src/core/lib/transport/connectivity_state.h", - "src/core/lib/transport/metadata.c", - "src/core/lib/transport/metadata.h", - "src/core/lib/transport/metadata_batch.c", - "src/core/lib/transport/metadata_batch.h", - "src/core/lib/transport/static_metadata.c", - "src/core/lib/transport/static_metadata.h", - "src/core/lib/transport/transport.c", - "src/core/lib/transport/transport.h", - "src/core/lib/transport/transport_impl.h", - "src/core/lib/transport/transport_op_string.c", - "src/core/lib/tsi/fake_transport_security.c", - "src/core/lib/tsi/fake_transport_security.h", - "src/core/lib/tsi/ssl_transport_security.c", - "src/core/lib/tsi/ssl_transport_security.h", - "src/core/lib/tsi/ssl_types.h", - "src/core/lib/tsi/transport_security.c", - "src/core/lib/tsi/transport_security.h", - "src/core/lib/tsi/transport_security_interface.h", - "src/core/plugin_registry/grpc_cronet_plugin_registry.c", - "third_party/nanopb/pb.h", - "third_party/nanopb/pb_decode.h", - "third_party/nanopb/pb_encode.h", - "third_party/objective_c/Cronet/cronet_c_for_grpc.h", - ], - hdrs = [ - "include/grpc/byte_buffer.h", - "include/grpc/byte_buffer_reader.h", - "include/grpc/compression.h", - "include/grpc/grpc.h", - "include/grpc/grpc_cronet.h", - "include/grpc/grpc_posix.h", - "include/grpc/grpc_security.h", - "include/grpc/grpc_security_constants.h", - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/byte_buffer.h", - "include/grpc/impl/codegen/byte_buffer_reader.h", - "include/grpc/impl/codegen/compression_types.h", - "include/grpc/impl/codegen/connectivity_state.h", - "include/grpc/impl/codegen/grpc_types.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/propagation_bits.h", - "include/grpc/impl/codegen/slice.h", - 
"include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/status.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - "include/grpc/status.h", - ], - includes = [ - ".", - "include", - ], - deps = [ - ":gpr", - "//external:libssl", - ], -) - -cc_library( - name = "grpc_unsecure", - srcs = [ - "src/core/ext/census/aggregation.h", - "src/core/ext/census/census_interface.h", - "src/core/ext/census/census_rpc_stats.h", - "src/core/ext/census/context.c", - "src/core/ext/census/gen/census.pb.c", - "src/core/ext/census/gen/census.pb.h", - "src/core/ext/census/grpc_context.c", - "src/core/ext/census/grpc_filter.c", - "src/core/ext/census/grpc_filter.h", - "src/core/ext/census/grpc_plugin.c", - "src/core/ext/census/initialize.c", - "src/core/ext/census/mlog.c", - "src/core/ext/census/mlog.h", - "src/core/ext/census/operation.c", - "src/core/ext/census/placeholders.c", - "src/core/ext/census/rpc_metric_id.h", - "src/core/ext/census/tracing.c", - "src/core/ext/client_config/channel_connectivity.c", - "src/core/ext/client_config/client_channel.c", - "src/core/ext/client_config/client_channel.h", - "src/core/ext/client_config/client_channel_factory.c", - "src/core/ext/client_config/client_channel_factory.h", - "src/core/ext/client_config/client_config.c", - "src/core/ext/client_config/client_config.h", - "src/core/ext/client_config/client_config_plugin.c", - "src/core/ext/client_config/connector.c", - "src/core/ext/client_config/connector.h", - "src/core/ext/client_config/default_initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.h", - "src/core/ext/client_config/lb_policy.c", - "src/core/ext/client_config/lb_policy.h", - "src/core/ext/client_config/lb_policy_factory.c", - 
"src/core/ext/client_config/lb_policy_factory.h", - "src/core/ext/client_config/lb_policy_registry.c", - "src/core/ext/client_config/lb_policy_registry.h", - "src/core/ext/client_config/parse_address.c", - "src/core/ext/client_config/parse_address.h", - "src/core/ext/client_config/resolver.c", - "src/core/ext/client_config/resolver.h", - "src/core/ext/client_config/resolver_factory.c", - "src/core/ext/client_config/resolver_factory.h", - "src/core/ext/client_config/resolver_registry.c", - "src/core/ext/client_config/resolver_registry.h", - "src/core/ext/client_config/subchannel.c", - "src/core/ext/client_config/subchannel.h", - "src/core/ext/client_config/subchannel_call_holder.c", - "src/core/ext/client_config/subchannel_call_holder.h", - "src/core/ext/client_config/subchannel_index.c", - "src/core/ext/client_config/subchannel_index.h", - "src/core/ext/client_config/uri_parser.c", - "src/core/ext/client_config/uri_parser.h", - "src/core/ext/lb_policy/grpclb/load_balancer_api.c", - "src/core/ext/lb_policy/grpclb/load_balancer_api.h", - "src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c", - "src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.h", - "src/core/ext/lb_policy/pick_first/pick_first.c", - "src/core/ext/lb_policy/round_robin/round_robin.c", - "src/core/ext/load_reporting/load_reporting.c", - "src/core/ext/load_reporting/load_reporting.h", - "src/core/ext/load_reporting/load_reporting_filter.c", - "src/core/ext/load_reporting/load_reporting_filter.h", - "src/core/ext/resolver/dns/native/dns_resolver.c", - "src/core/ext/resolver/sockaddr/sockaddr_resolver.c", - "src/core/ext/transport/chttp2/alpn/alpn.c", - "src/core/ext/transport/chttp2/alpn/alpn.h", - "src/core/ext/transport/chttp2/client/insecure/channel_create.c", - "src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c", - "src/core/ext/transport/chttp2/server/insecure/server_chttp2.c", - "src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c", - 
"src/core/ext/transport/chttp2/transport/bin_decoder.c", - "src/core/ext/transport/chttp2/transport/bin_decoder.h", - "src/core/ext/transport/chttp2/transport/bin_encoder.c", - "src/core/ext/transport/chttp2/transport/bin_encoder.h", - "src/core/ext/transport/chttp2/transport/chttp2_plugin.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.h", - "src/core/ext/transport/chttp2/transport/frame.h", - "src/core/ext/transport/chttp2/transport/frame_data.c", - "src/core/ext/transport/chttp2/transport/frame_data.h", - "src/core/ext/transport/chttp2/transport/frame_goaway.c", - "src/core/ext/transport/chttp2/transport/frame_goaway.h", - "src/core/ext/transport/chttp2/transport/frame_ping.c", - "src/core/ext/transport/chttp2/transport/frame_ping.h", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.c", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.h", - "src/core/ext/transport/chttp2/transport/frame_settings.c", - "src/core/ext/transport/chttp2/transport/frame_settings.h", - "src/core/ext/transport/chttp2/transport/frame_window_update.c", - "src/core/ext/transport/chttp2/transport/frame_window_update.h", - "src/core/ext/transport/chttp2/transport/hpack_encoder.c", - "src/core/ext/transport/chttp2/transport/hpack_encoder.h", - "src/core/ext/transport/chttp2/transport/hpack_parser.c", - "src/core/ext/transport/chttp2/transport/hpack_parser.h", - "src/core/ext/transport/chttp2/transport/hpack_table.c", - "src/core/ext/transport/chttp2/transport/hpack_table.h", - "src/core/ext/transport/chttp2/transport/http2_errors.h", - "src/core/ext/transport/chttp2/transport/huffsyms.c", - "src/core/ext/transport/chttp2/transport/huffsyms.h", - "src/core/ext/transport/chttp2/transport/incoming_metadata.c", - "src/core/ext/transport/chttp2/transport/incoming_metadata.h", - "src/core/ext/transport/chttp2/transport/internal.h", - "src/core/ext/transport/chttp2/transport/parsing.c", - 
"src/core/ext/transport/chttp2/transport/status_conversion.c", - "src/core/ext/transport/chttp2/transport/status_conversion.h", - "src/core/ext/transport/chttp2/transport/stream_lists.c", - "src/core/ext/transport/chttp2/transport/stream_map.c", - "src/core/ext/transport/chttp2/transport/stream_map.h", - "src/core/ext/transport/chttp2/transport/timeout_encoding.c", - "src/core/ext/transport/chttp2/transport/timeout_encoding.h", - "src/core/ext/transport/chttp2/transport/varint.c", - "src/core/ext/transport/chttp2/transport/varint.h", - "src/core/ext/transport/chttp2/transport/writing.c", - "src/core/lib/channel/channel_args.c", - "src/core/lib/channel/channel_args.h", - "src/core/lib/channel/channel_stack.c", - "src/core/lib/channel/channel_stack.h", - "src/core/lib/channel/channel_stack_builder.c", - "src/core/lib/channel/channel_stack_builder.h", - "src/core/lib/channel/compress_filter.c", - "src/core/lib/channel/compress_filter.h", - "src/core/lib/channel/connected_channel.c", - "src/core/lib/channel/connected_channel.h", - "src/core/lib/channel/context.h", - "src/core/lib/channel/http_client_filter.c", - "src/core/lib/channel/http_client_filter.h", - "src/core/lib/channel/http_server_filter.c", - "src/core/lib/channel/http_server_filter.h", - "src/core/lib/compression/algorithm_metadata.h", - "src/core/lib/compression/compression.c", - "src/core/lib/compression/message_compress.c", - "src/core/lib/compression/message_compress.h", - "src/core/lib/debug/trace.c", - "src/core/lib/debug/trace.h", - "src/core/lib/http/format_request.c", - "src/core/lib/http/format_request.h", - "src/core/lib/http/httpcli.c", - "src/core/lib/http/httpcli.h", - "src/core/lib/http/parser.c", - "src/core/lib/http/parser.h", - "src/core/lib/iomgr/closure.c", - "src/core/lib/iomgr/closure.h", - "src/core/lib/iomgr/endpoint.c", - "src/core/lib/iomgr/endpoint.h", - "src/core/lib/iomgr/endpoint_pair.h", - "src/core/lib/iomgr/endpoint_pair_posix.c", - 
"src/core/lib/iomgr/endpoint_pair_windows.c", - "src/core/lib/iomgr/error.c", - "src/core/lib/iomgr/error.h", - "src/core/lib/iomgr/ev_epoll_linux.c", - "src/core/lib/iomgr/ev_epoll_linux.h", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.c", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.h", - "src/core/lib/iomgr/ev_poll_posix.c", - "src/core/lib/iomgr/ev_poll_posix.h", - "src/core/lib/iomgr/ev_posix.c", - "src/core/lib/iomgr/ev_posix.h", - "src/core/lib/iomgr/exec_ctx.c", - "src/core/lib/iomgr/exec_ctx.h", - "src/core/lib/iomgr/executor.c", - "src/core/lib/iomgr/executor.h", - "src/core/lib/iomgr/iocp_windows.c", - "src/core/lib/iomgr/iocp_windows.h", - "src/core/lib/iomgr/iomgr.c", - "src/core/lib/iomgr/iomgr.h", - "src/core/lib/iomgr/iomgr_internal.h", - "src/core/lib/iomgr/iomgr_posix.c", - "src/core/lib/iomgr/iomgr_posix.h", - "src/core/lib/iomgr/iomgr_windows.c", - "src/core/lib/iomgr/load_file.c", - "src/core/lib/iomgr/load_file.h", - "src/core/lib/iomgr/network_status_tracker.c", - "src/core/lib/iomgr/network_status_tracker.h", - "src/core/lib/iomgr/polling_entity.c", - "src/core/lib/iomgr/polling_entity.h", - "src/core/lib/iomgr/pollset.h", - "src/core/lib/iomgr/pollset_set.h", - "src/core/lib/iomgr/pollset_set_windows.c", - "src/core/lib/iomgr/pollset_set_windows.h", - "src/core/lib/iomgr/pollset_windows.c", - "src/core/lib/iomgr/pollset_windows.h", - "src/core/lib/iomgr/resolve_address.h", - "src/core/lib/iomgr/resolve_address_posix.c", - "src/core/lib/iomgr/resolve_address_windows.c", - "src/core/lib/iomgr/sockaddr.h", - "src/core/lib/iomgr/sockaddr_posix.h", - "src/core/lib/iomgr/sockaddr_utils.c", - "src/core/lib/iomgr/sockaddr_utils.h", - "src/core/lib/iomgr/sockaddr_windows.h", - "src/core/lib/iomgr/socket_utils_common_posix.c", - "src/core/lib/iomgr/socket_utils_linux.c", - "src/core/lib/iomgr/socket_utils_posix.c", - "src/core/lib/iomgr/socket_utils_posix.h", - "src/core/lib/iomgr/socket_windows.c", - "src/core/lib/iomgr/socket_windows.h", - 
"src/core/lib/iomgr/tcp_client.h", - "src/core/lib/iomgr/tcp_client_posix.c", - "src/core/lib/iomgr/tcp_client_windows.c", - "src/core/lib/iomgr/tcp_posix.c", - "src/core/lib/iomgr/tcp_posix.h", - "src/core/lib/iomgr/tcp_server.h", - "src/core/lib/iomgr/tcp_server_posix.c", - "src/core/lib/iomgr/tcp_server_windows.c", - "src/core/lib/iomgr/tcp_windows.c", - "src/core/lib/iomgr/tcp_windows.h", - "src/core/lib/iomgr/time_averaged_stats.c", - "src/core/lib/iomgr/time_averaged_stats.h", - "src/core/lib/iomgr/timer.c", - "src/core/lib/iomgr/timer.h", - "src/core/lib/iomgr/timer_heap.c", - "src/core/lib/iomgr/timer_heap.h", - "src/core/lib/iomgr/udp_server.c", - "src/core/lib/iomgr/udp_server.h", - "src/core/lib/iomgr/unix_sockets_posix.c", - "src/core/lib/iomgr/unix_sockets_posix.h", - "src/core/lib/iomgr/unix_sockets_posix_noop.c", - "src/core/lib/iomgr/wakeup_fd_eventfd.c", - "src/core/lib/iomgr/wakeup_fd_nospecial.c", - "src/core/lib/iomgr/wakeup_fd_pipe.c", - "src/core/lib/iomgr/wakeup_fd_pipe.h", - "src/core/lib/iomgr/wakeup_fd_posix.c", - "src/core/lib/iomgr/wakeup_fd_posix.h", - "src/core/lib/iomgr/workqueue.h", - "src/core/lib/iomgr/workqueue_posix.c", - "src/core/lib/iomgr/workqueue_posix.h", - "src/core/lib/iomgr/workqueue_windows.c", - "src/core/lib/iomgr/workqueue_windows.h", - "src/core/lib/json/json.c", - "src/core/lib/json/json.h", - "src/core/lib/json/json_common.h", - "src/core/lib/json/json_reader.c", - "src/core/lib/json/json_reader.h", - "src/core/lib/json/json_string.c", - "src/core/lib/json/json_writer.c", - "src/core/lib/json/json_writer.h", - "src/core/lib/surface/alarm.c", - "src/core/lib/surface/api_trace.c", - "src/core/lib/surface/api_trace.h", - "src/core/lib/surface/byte_buffer.c", - "src/core/lib/surface/byte_buffer_reader.c", - "src/core/lib/surface/call.c", - "src/core/lib/surface/call.h", - "src/core/lib/surface/call_details.c", - "src/core/lib/surface/call_log_batch.c", - "src/core/lib/surface/call_test_only.h", - 
"src/core/lib/surface/channel.c", - "src/core/lib/surface/channel.h", - "src/core/lib/surface/channel_init.c", - "src/core/lib/surface/channel_init.h", - "src/core/lib/surface/channel_ping.c", - "src/core/lib/surface/channel_stack_type.c", - "src/core/lib/surface/channel_stack_type.h", - "src/core/lib/surface/completion_queue.c", - "src/core/lib/surface/completion_queue.h", - "src/core/lib/surface/event_string.c", - "src/core/lib/surface/event_string.h", - "src/core/lib/surface/init.c", - "src/core/lib/surface/init.h", - "src/core/lib/surface/init_unsecure.c", - "src/core/lib/surface/lame_client.c", - "src/core/lib/surface/lame_client.h", - "src/core/lib/surface/metadata_array.c", - "src/core/lib/surface/server.c", - "src/core/lib/surface/server.h", - "src/core/lib/surface/validate_metadata.c", - "src/core/lib/surface/version.c", - "src/core/lib/transport/byte_stream.c", - "src/core/lib/transport/byte_stream.h", - "src/core/lib/transport/connectivity_state.c", - "src/core/lib/transport/connectivity_state.h", - "src/core/lib/transport/metadata.c", - "src/core/lib/transport/metadata.h", - "src/core/lib/transport/metadata_batch.c", - "src/core/lib/transport/metadata_batch.h", - "src/core/lib/transport/static_metadata.c", - "src/core/lib/transport/static_metadata.h", - "src/core/lib/transport/transport.c", - "src/core/lib/transport/transport.h", - "src/core/lib/transport/transport_impl.h", - "src/core/lib/transport/transport_op_string.c", - "src/core/plugin_registry/grpc_unsecure_plugin_registry.c", - "third_party/nanopb/pb.h", - "third_party/nanopb/pb_decode.h", - "third_party/nanopb/pb_encode.h", - ], - hdrs = [ - "include/grpc/byte_buffer.h", - "include/grpc/byte_buffer_reader.h", - "include/grpc/census.h", - "include/grpc/compression.h", - "include/grpc/grpc.h", - "include/grpc/grpc_posix.h", - "include/grpc/grpc_security_constants.h", - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - 
"include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/byte_buffer.h", - "include/grpc/impl/codegen/byte_buffer_reader.h", - "include/grpc/impl/codegen/compression_types.h", - "include/grpc/impl/codegen/connectivity_state.h", - "include/grpc/impl/codegen/grpc_types.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/propagation_bits.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/status.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - "include/grpc/status.h", - ], - copts = [ - "-std=gnu99", - ], - includes = [ - ".", - "include", - ], - deps = [ - ":gpr", - "//external:nanopb", - "//external:zlib", - ], -) - -cc_library( - name = "grpc++", - srcs = [ - "include/grpc++/impl/codegen/core_codegen.h", - "src/core/lib/channel/channel_args.c", - "src/core/lib/channel/channel_args.h", - "src/core/lib/channel/channel_stack.c", - "src/core/lib/channel/channel_stack.h", - "src/core/lib/channel/channel_stack_builder.c", - "src/core/lib/channel/channel_stack_builder.h", - "src/core/lib/channel/compress_filter.c", - "src/core/lib/channel/compress_filter.h", - "src/core/lib/channel/connected_channel.c", - "src/core/lib/channel/connected_channel.h", - "src/core/lib/channel/context.h", - "src/core/lib/channel/http_client_filter.c", - "src/core/lib/channel/http_client_filter.h", - "src/core/lib/channel/http_server_filter.c", - "src/core/lib/channel/http_server_filter.h", - "src/core/lib/compression/algorithm_metadata.h", - "src/core/lib/compression/compression.c", - "src/core/lib/compression/message_compress.c", - "src/core/lib/compression/message_compress.h", - "src/core/lib/debug/trace.c", - 
"src/core/lib/debug/trace.h", - "src/core/lib/http/format_request.c", - "src/core/lib/http/format_request.h", - "src/core/lib/http/httpcli.c", - "src/core/lib/http/httpcli.h", - "src/core/lib/http/parser.c", - "src/core/lib/http/parser.h", - "src/core/lib/iomgr/closure.c", - "src/core/lib/iomgr/closure.h", - "src/core/lib/iomgr/endpoint.c", - "src/core/lib/iomgr/endpoint.h", - "src/core/lib/iomgr/endpoint_pair.h", - "src/core/lib/iomgr/endpoint_pair_posix.c", - "src/core/lib/iomgr/endpoint_pair_windows.c", - "src/core/lib/iomgr/error.c", - "src/core/lib/iomgr/error.h", - "src/core/lib/iomgr/ev_epoll_linux.c", - "src/core/lib/iomgr/ev_epoll_linux.h", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.c", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.h", - "src/core/lib/iomgr/ev_poll_posix.c", - "src/core/lib/iomgr/ev_poll_posix.h", - "src/core/lib/iomgr/ev_posix.c", - "src/core/lib/iomgr/ev_posix.h", - "src/core/lib/iomgr/exec_ctx.c", - "src/core/lib/iomgr/exec_ctx.h", - "src/core/lib/iomgr/executor.c", - "src/core/lib/iomgr/executor.h", - "src/core/lib/iomgr/iocp_windows.c", - "src/core/lib/iomgr/iocp_windows.h", - "src/core/lib/iomgr/iomgr.c", - "src/core/lib/iomgr/iomgr.h", - "src/core/lib/iomgr/iomgr_internal.h", - "src/core/lib/iomgr/iomgr_posix.c", - "src/core/lib/iomgr/iomgr_posix.h", - "src/core/lib/iomgr/iomgr_windows.c", - "src/core/lib/iomgr/load_file.c", - "src/core/lib/iomgr/load_file.h", - "src/core/lib/iomgr/network_status_tracker.c", - "src/core/lib/iomgr/network_status_tracker.h", - "src/core/lib/iomgr/polling_entity.c", - "src/core/lib/iomgr/polling_entity.h", - "src/core/lib/iomgr/pollset.h", - "src/core/lib/iomgr/pollset_set.h", - "src/core/lib/iomgr/pollset_set_windows.c", - "src/core/lib/iomgr/pollset_set_windows.h", - "src/core/lib/iomgr/pollset_windows.c", - "src/core/lib/iomgr/pollset_windows.h", - "src/core/lib/iomgr/resolve_address.h", - "src/core/lib/iomgr/resolve_address_posix.c", - "src/core/lib/iomgr/resolve_address_windows.c", - 
"src/core/lib/iomgr/sockaddr.h", - "src/core/lib/iomgr/sockaddr_posix.h", - "src/core/lib/iomgr/sockaddr_utils.c", - "src/core/lib/iomgr/sockaddr_utils.h", - "src/core/lib/iomgr/sockaddr_windows.h", - "src/core/lib/iomgr/socket_utils_common_posix.c", - "src/core/lib/iomgr/socket_utils_linux.c", - "src/core/lib/iomgr/socket_utils_posix.c", - "src/core/lib/iomgr/socket_utils_posix.h", - "src/core/lib/iomgr/socket_windows.c", - "src/core/lib/iomgr/socket_windows.h", - "src/core/lib/iomgr/tcp_client.h", - "src/core/lib/iomgr/tcp_client_posix.c", - "src/core/lib/iomgr/tcp_client_windows.c", - "src/core/lib/iomgr/tcp_posix.c", - "src/core/lib/iomgr/tcp_posix.h", - "src/core/lib/iomgr/tcp_server.h", - "src/core/lib/iomgr/tcp_server_posix.c", - "src/core/lib/iomgr/tcp_server_windows.c", - "src/core/lib/iomgr/tcp_windows.c", - "src/core/lib/iomgr/tcp_windows.h", - "src/core/lib/iomgr/time_averaged_stats.c", - "src/core/lib/iomgr/time_averaged_stats.h", - "src/core/lib/iomgr/timer.c", - "src/core/lib/iomgr/timer.h", - "src/core/lib/iomgr/timer_heap.c", - "src/core/lib/iomgr/timer_heap.h", - "src/core/lib/iomgr/udp_server.c", - "src/core/lib/iomgr/udp_server.h", - "src/core/lib/iomgr/unix_sockets_posix.c", - "src/core/lib/iomgr/unix_sockets_posix.h", - "src/core/lib/iomgr/unix_sockets_posix_noop.c", - "src/core/lib/iomgr/wakeup_fd_eventfd.c", - "src/core/lib/iomgr/wakeup_fd_nospecial.c", - "src/core/lib/iomgr/wakeup_fd_pipe.c", - "src/core/lib/iomgr/wakeup_fd_pipe.h", - "src/core/lib/iomgr/wakeup_fd_posix.c", - "src/core/lib/iomgr/wakeup_fd_posix.h", - "src/core/lib/iomgr/workqueue.h", - "src/core/lib/iomgr/workqueue_posix.c", - "src/core/lib/iomgr/workqueue_posix.h", - "src/core/lib/iomgr/workqueue_windows.c", - "src/core/lib/iomgr/workqueue_windows.h", - "src/core/lib/json/json.c", - "src/core/lib/json/json.h", - "src/core/lib/json/json_common.h", - "src/core/lib/json/json_reader.c", - "src/core/lib/json/json_reader.h", - "src/core/lib/json/json_string.c", - 
"src/core/lib/json/json_writer.c", - "src/core/lib/json/json_writer.h", - "src/core/lib/surface/alarm.c", - "src/core/lib/surface/api_trace.c", - "src/core/lib/surface/api_trace.h", - "src/core/lib/surface/byte_buffer.c", - "src/core/lib/surface/byte_buffer_reader.c", - "src/core/lib/surface/call.c", - "src/core/lib/surface/call.h", - "src/core/lib/surface/call_details.c", - "src/core/lib/surface/call_log_batch.c", - "src/core/lib/surface/call_test_only.h", - "src/core/lib/surface/channel.c", - "src/core/lib/surface/channel.h", - "src/core/lib/surface/channel_init.c", - "src/core/lib/surface/channel_init.h", - "src/core/lib/surface/channel_ping.c", - "src/core/lib/surface/channel_stack_type.c", - "src/core/lib/surface/channel_stack_type.h", - "src/core/lib/surface/completion_queue.c", - "src/core/lib/surface/completion_queue.h", - "src/core/lib/surface/event_string.c", - "src/core/lib/surface/event_string.h", - "src/core/lib/surface/init.h", - "src/core/lib/surface/lame_client.c", - "src/core/lib/surface/lame_client.h", - "src/core/lib/surface/metadata_array.c", - "src/core/lib/surface/server.c", - "src/core/lib/surface/server.h", - "src/core/lib/surface/validate_metadata.c", - "src/core/lib/surface/version.c", - "src/core/lib/transport/byte_stream.c", - "src/core/lib/transport/byte_stream.h", - "src/core/lib/transport/connectivity_state.c", - "src/core/lib/transport/connectivity_state.h", - "src/core/lib/transport/metadata.c", - "src/core/lib/transport/metadata.h", - "src/core/lib/transport/metadata_batch.c", - "src/core/lib/transport/metadata_batch.h", - "src/core/lib/transport/static_metadata.c", - "src/core/lib/transport/static_metadata.h", - "src/core/lib/transport/transport.c", - "src/core/lib/transport/transport.h", - "src/core/lib/transport/transport_impl.h", - "src/core/lib/transport/transport_op_string.c", - "src/cpp/client/channel.cc", - "src/cpp/client/client_context.cc", - "src/cpp/client/create_channel.cc", - 
"src/cpp/client/create_channel_internal.cc", - "src/cpp/client/create_channel_internal.h", - "src/cpp/client/create_channel_posix.cc", - "src/cpp/client/credentials.cc", - "src/cpp/client/generic_stub.cc", - "src/cpp/client/insecure_credentials.cc", - "src/cpp/client/secure_credentials.cc", - "src/cpp/client/secure_credentials.h", - "src/cpp/codegen/codegen_init.cc", - "src/cpp/common/auth_property_iterator.cc", - "src/cpp/common/channel_arguments.cc", - "src/cpp/common/completion_queue.cc", - "src/cpp/common/core_codegen.cc", - "src/cpp/common/rpc_method.cc", - "src/cpp/common/secure_auth_context.cc", - "src/cpp/common/secure_auth_context.h", - "src/cpp/common/secure_channel_arguments.cc", - "src/cpp/common/secure_create_auth_context.cc", - "src/cpp/server/async_generic_service.cc", - "src/cpp/server/create_default_thread_pool.cc", - "src/cpp/server/dynamic_thread_pool.cc", - "src/cpp/server/dynamic_thread_pool.h", - "src/cpp/server/insecure_server_credentials.cc", - "src/cpp/server/secure_server_credentials.cc", - "src/cpp/server/secure_server_credentials.h", - "src/cpp/server/server.cc", - "src/cpp/server/server_builder.cc", - "src/cpp/server/server_context.cc", - "src/cpp/server/server_credentials.cc", - "src/cpp/server/server_posix.cc", - "src/cpp/server/thread_pool_interface.h", - "src/cpp/util/byte_buffer.cc", - "src/cpp/util/slice.cc", - "src/cpp/util/status.cc", - "src/cpp/util/string_ref.cc", - "src/cpp/util/time.cc", - ], - hdrs = [ - "include/grpc++/alarm.h", - "include/grpc++/channel.h", - "include/grpc++/client_context.h", - "include/grpc++/completion_queue.h", - "include/grpc++/create_channel.h", - "include/grpc++/create_channel_posix.h", - "include/grpc++/generic/async_generic_service.h", - "include/grpc++/generic/generic_stub.h", - "include/grpc++/grpc++.h", - "include/grpc++/impl/call.h", - "include/grpc++/impl/client_unary_call.h", - "include/grpc++/impl/codegen/async_stream.h", - "include/grpc++/impl/codegen/async_unary_call.h", - 
"include/grpc++/impl/codegen/call.h", - "include/grpc++/impl/codegen/call_hook.h", - "include/grpc++/impl/codegen/channel_interface.h", - "include/grpc++/impl/codegen/client_context.h", - "include/grpc++/impl/codegen/client_unary_call.h", - "include/grpc++/impl/codegen/completion_queue.h", - "include/grpc++/impl/codegen/completion_queue_tag.h", - "include/grpc++/impl/codegen/config.h", - "include/grpc++/impl/codegen/core_codegen.h", - "include/grpc++/impl/codegen/core_codegen_interface.h", - "include/grpc++/impl/codegen/create_auth_context.h", - "include/grpc++/impl/codegen/grpc_library.h", - "include/grpc++/impl/codegen/method_handler_impl.h", - "include/grpc++/impl/codegen/proto_utils.h", - "include/grpc++/impl/codegen/rpc_method.h", - "include/grpc++/impl/codegen/rpc_service_method.h", - "include/grpc++/impl/codegen/security/auth_context.h", - "include/grpc++/impl/codegen/serialization_traits.h", - "include/grpc++/impl/codegen/server_context.h", - "include/grpc++/impl/codegen/server_interface.h", - "include/grpc++/impl/codegen/service_type.h", - "include/grpc++/impl/codegen/status.h", - "include/grpc++/impl/codegen/status_code_enum.h", - "include/grpc++/impl/codegen/string_ref.h", - "include/grpc++/impl/codegen/stub_options.h", - "include/grpc++/impl/codegen/sync.h", - "include/grpc++/impl/codegen/sync_cxx11.h", - "include/grpc++/impl/codegen/sync_no_cxx11.h", - "include/grpc++/impl/codegen/sync_stream.h", - "include/grpc++/impl/codegen/time.h", - "include/grpc++/impl/grpc_library.h", - "include/grpc++/impl/method_handler_impl.h", - "include/grpc++/impl/rpc_method.h", - "include/grpc++/impl/rpc_service_method.h", - "include/grpc++/impl/serialization_traits.h", - "include/grpc++/impl/server_builder_option.h", - "include/grpc++/impl/server_builder_plugin.h", - "include/grpc++/impl/server_initializer.h", - "include/grpc++/impl/service_type.h", - "include/grpc++/impl/sync.h", - "include/grpc++/impl/sync_cxx11.h", - "include/grpc++/impl/sync_no_cxx11.h", - 
"include/grpc++/impl/thd.h", - "include/grpc++/impl/thd_cxx11.h", - "include/grpc++/impl/thd_no_cxx11.h", - "include/grpc++/security/auth_context.h", - "include/grpc++/security/auth_metadata_processor.h", - "include/grpc++/security/credentials.h", - "include/grpc++/security/server_credentials.h", - "include/grpc++/server.h", - "include/grpc++/server_builder.h", - "include/grpc++/server_context.h", - "include/grpc++/server_posix.h", - "include/grpc++/support/async_stream.h", - "include/grpc++/support/async_unary_call.h", - "include/grpc++/support/byte_buffer.h", - "include/grpc++/support/channel_arguments.h", - "include/grpc++/support/config.h", - "include/grpc++/support/slice.h", - "include/grpc++/support/status.h", - "include/grpc++/support/status_code_enum.h", - "include/grpc++/support/string_ref.h", - "include/grpc++/support/stub_options.h", - "include/grpc++/support/sync_stream.h", - "include/grpc++/support/time.h", - "include/grpc/byte_buffer.h", - "include/grpc/byte_buffer_reader.h", - "include/grpc/compression.h", - "include/grpc/grpc.h", - "include/grpc/grpc_posix.h", - "include/grpc/grpc_security_constants.h", - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/byte_buffer.h", - "include/grpc/impl/codegen/byte_buffer_reader.h", - "include/grpc/impl/codegen/compression_types.h", - "include/grpc/impl/codegen/connectivity_state.h", - "include/grpc/impl/codegen/grpc_types.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/propagation_bits.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/status.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - 
"include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - "include/grpc/status.h", - ], - copts = [ - "-std=gnu99", - ], - includes = [ - ".", - "include", - ], - deps = [ - ":gpr", - ":grpc", - "//external:libssl", - "//external:protobuf_clib", - ], -) - -cc_library( - name = "grpc++_reflection", - srcs = [ - "src/cpp/ext/proto_server_reflection.cc", - "src/cpp/ext/proto_server_reflection.h", - "src/cpp/ext/proto_server_reflection_plugin.cc", - "src/cpp/ext/reflection.grpc.pb.cc", - "src/cpp/ext/reflection.pb.cc", - ], - hdrs = [ - "include/grpc++/ext/proto_server_reflection_plugin.h", - "include/grpc++/ext/reflection.grpc.pb.h", - "include/grpc++/ext/reflection.pb.h", - "include/grpc++/impl/codegen/async_stream.h", - "include/grpc++/impl/codegen/async_unary_call.h", - "include/grpc++/impl/codegen/call.h", - "include/grpc++/impl/codegen/call_hook.h", - "include/grpc++/impl/codegen/channel_interface.h", - "include/grpc++/impl/codegen/client_context.h", - "include/grpc++/impl/codegen/client_unary_call.h", - "include/grpc++/impl/codegen/completion_queue.h", - "include/grpc++/impl/codegen/completion_queue_tag.h", - "include/grpc++/impl/codegen/config.h", - "include/grpc++/impl/codegen/config_protobuf.h", - "include/grpc++/impl/codegen/core_codegen_interface.h", - "include/grpc++/impl/codegen/create_auth_context.h", - "include/grpc++/impl/codegen/grpc_library.h", - "include/grpc++/impl/codegen/method_handler_impl.h", - "include/grpc++/impl/codegen/proto_utils.h", - "include/grpc++/impl/codegen/rpc_method.h", - "include/grpc++/impl/codegen/rpc_service_method.h", - "include/grpc++/impl/codegen/security/auth_context.h", - "include/grpc++/impl/codegen/serialization_traits.h", - "include/grpc++/impl/codegen/server_context.h", - "include/grpc++/impl/codegen/server_interface.h", - "include/grpc++/impl/codegen/service_type.h", - "include/grpc++/impl/codegen/status.h", - "include/grpc++/impl/codegen/status_code_enum.h", - 
"include/grpc++/impl/codegen/string_ref.h", - "include/grpc++/impl/codegen/stub_options.h", - "include/grpc++/impl/codegen/sync.h", - "include/grpc++/impl/codegen/sync_cxx11.h", - "include/grpc++/impl/codegen/sync_no_cxx11.h", - "include/grpc++/impl/codegen/sync_stream.h", - "include/grpc++/impl/codegen/time.h", - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/byte_buffer.h", - "include/grpc/impl/codegen/byte_buffer_reader.h", - "include/grpc/impl/codegen/compression_types.h", - "include/grpc/impl/codegen/connectivity_state.h", - "include/grpc/impl/codegen/grpc_types.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/propagation_bits.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/status.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - ], - includes = [ - ".", - "include", - ], - deps = [ - ":grpc++", - ], -) - -cc_library( - name = "grpc++_unsecure", - srcs = [ - "src/cpp/client/channel.cc", - "src/cpp/client/client_context.cc", - "src/cpp/client/create_channel.cc", - "src/cpp/client/create_channel_internal.cc", - "src/cpp/client/create_channel_internal.h", - "src/cpp/client/create_channel_posix.cc", - "src/cpp/client/credentials.cc", - "src/cpp/client/generic_stub.cc", - "src/cpp/client/insecure_credentials.cc", - "src/cpp/codegen/codegen_init.cc", - "src/cpp/common/channel_arguments.cc", - "src/cpp/common/completion_queue.cc", - "src/cpp/common/core_codegen.cc", - "src/cpp/common/insecure_create_auth_context.cc", - "src/cpp/common/rpc_method.cc", - 
"src/cpp/server/async_generic_service.cc", - "src/cpp/server/create_default_thread_pool.cc", - "src/cpp/server/dynamic_thread_pool.cc", - "src/cpp/server/dynamic_thread_pool.h", - "src/cpp/server/insecure_server_credentials.cc", - "src/cpp/server/server.cc", - "src/cpp/server/server_builder.cc", - "src/cpp/server/server_context.cc", - "src/cpp/server/server_credentials.cc", - "src/cpp/server/server_posix.cc", - "src/cpp/server/thread_pool_interface.h", - "src/cpp/util/byte_buffer.cc", - "src/cpp/util/slice.cc", - "src/cpp/util/status.cc", - "src/cpp/util/string_ref.cc", - "src/cpp/util/time.cc", - ], - hdrs = [ - "include/grpc++/alarm.h", - "include/grpc++/channel.h", - "include/grpc++/client_context.h", - "include/grpc++/completion_queue.h", - "include/grpc++/create_channel.h", - "include/grpc++/create_channel_posix.h", - "include/grpc++/generic/async_generic_service.h", - "include/grpc++/generic/generic_stub.h", - "include/grpc++/grpc++.h", - "include/grpc++/impl/call.h", - "include/grpc++/impl/client_unary_call.h", - "include/grpc++/impl/codegen/async_stream.h", - "include/grpc++/impl/codegen/async_unary_call.h", - "include/grpc++/impl/codegen/call.h", - "include/grpc++/impl/codegen/call_hook.h", - "include/grpc++/impl/codegen/channel_interface.h", - "include/grpc++/impl/codegen/client_context.h", - "include/grpc++/impl/codegen/client_unary_call.h", - "include/grpc++/impl/codegen/completion_queue.h", - "include/grpc++/impl/codegen/completion_queue_tag.h", - "include/grpc++/impl/codegen/config.h", - "include/grpc++/impl/codegen/config_protobuf.h", - "include/grpc++/impl/codegen/core_codegen.h", - "include/grpc++/impl/codegen/core_codegen_interface.h", - "include/grpc++/impl/codegen/create_auth_context.h", - "include/grpc++/impl/codegen/grpc_library.h", - "include/grpc++/impl/codegen/method_handler_impl.h", - "include/grpc++/impl/codegen/proto_utils.h", - "include/grpc++/impl/codegen/rpc_method.h", - "include/grpc++/impl/codegen/rpc_service_method.h", - 
"include/grpc++/impl/codegen/security/auth_context.h", - "include/grpc++/impl/codegen/serialization_traits.h", - "include/grpc++/impl/codegen/server_context.h", - "include/grpc++/impl/codegen/server_interface.h", - "include/grpc++/impl/codegen/service_type.h", - "include/grpc++/impl/codegen/status.h", - "include/grpc++/impl/codegen/status_code_enum.h", - "include/grpc++/impl/codegen/string_ref.h", - "include/grpc++/impl/codegen/stub_options.h", - "include/grpc++/impl/codegen/sync.h", - "include/grpc++/impl/codegen/sync_cxx11.h", - "include/grpc++/impl/codegen/sync_no_cxx11.h", - "include/grpc++/impl/codegen/sync_stream.h", - "include/grpc++/impl/codegen/time.h", - "include/grpc++/impl/grpc_library.h", - "include/grpc++/impl/method_handler_impl.h", - "include/grpc++/impl/rpc_method.h", - "include/grpc++/impl/rpc_service_method.h", - "include/grpc++/impl/serialization_traits.h", - "include/grpc++/impl/server_builder_option.h", - "include/grpc++/impl/server_builder_plugin.h", - "include/grpc++/impl/server_initializer.h", - "include/grpc++/impl/service_type.h", - "include/grpc++/impl/sync.h", - "include/grpc++/impl/sync_cxx11.h", - "include/grpc++/impl/sync_no_cxx11.h", - "include/grpc++/impl/thd.h", - "include/grpc++/impl/thd_cxx11.h", - "include/grpc++/impl/thd_no_cxx11.h", - "include/grpc++/security/auth_context.h", - "include/grpc++/security/auth_metadata_processor.h", - "include/grpc++/security/credentials.h", - "include/grpc++/security/server_credentials.h", - "include/grpc++/server.h", - "include/grpc++/server_builder.h", - "include/grpc++/server_context.h", - "include/grpc++/server_posix.h", - "include/grpc++/support/async_stream.h", - "include/grpc++/support/async_unary_call.h", - "include/grpc++/support/byte_buffer.h", - "include/grpc++/support/channel_arguments.h", - "include/grpc++/support/config.h", - "include/grpc++/support/slice.h", - "include/grpc++/support/status.h", - "include/grpc++/support/status_code_enum.h", - 
"include/grpc++/support/string_ref.h", - "include/grpc++/support/stub_options.h", - "include/grpc++/support/sync_stream.h", - "include/grpc++/support/time.h", - ], - includes = [ - ".", - "include", - ], - linkopts = ["-lpthread"], - deps = [ - ":gpr", - ":grpc_unsecure", - "//external:protobuf_clib", - ], -) - -cc_library( - name = "grpc_plugin_support", - srcs = [ - "src/compiler/config.h", - "src/compiler/cpp_generator.cc", - "src/compiler/cpp_generator.h", - "src/compiler/cpp_generator_helpers.h", - "src/compiler/csharp_generator.cc", - "src/compiler/csharp_generator.h", - "src/compiler/csharp_generator_helpers.h", - "src/compiler/generator_helpers.h", - "src/compiler/node_generator.cc", - "src/compiler/node_generator.h", - "src/compiler/node_generator_helpers.h", - "src/compiler/objective_c_generator.cc", - "src/compiler/objective_c_generator.h", - "src/compiler/objective_c_generator_helpers.h", - "src/compiler/python_generator.cc", - "src/compiler/python_generator.h", - "src/compiler/ruby_generator.cc", - "src/compiler/ruby_generator.h", - "src/compiler/ruby_generator_helpers-inl.h", - "src/compiler/ruby_generator_map-inl.h", - "src/compiler/ruby_generator_string-inl.h", - ], - hdrs = [ - "include/grpc++/impl/codegen/config_protobuf.h", - ], - includes = [ - ".", - "include", - ], - deps = [ - "//external:protobuf_compiler", - ], -) - -cc_library( - name = "grpc_csharp_ext", - srcs = [ - "src/csharp/ext/grpc_csharp_ext.c", - ], - hdrs = [ - ], - includes = [ - ".", - "include", - ], - deps = [ - ":gpr", - ":grpc", - ], -) - -objc_library( - name = "gpr_objc", - srcs = [ - "src/core/lib/profiling/basic_timers.c", - "src/core/lib/profiling/stap_timers.c", - "src/core/lib/support/alloc.c", - "src/core/lib/support/avl.c", - "src/core/lib/support/backoff.c", - "src/core/lib/support/cmdline.c", - "src/core/lib/support/cpu_iphone.c", - "src/core/lib/support/cpu_linux.c", - "src/core/lib/support/cpu_posix.c", - "src/core/lib/support/cpu_windows.c", - 
"src/core/lib/support/env_linux.c", - "src/core/lib/support/env_posix.c", - "src/core/lib/support/env_windows.c", - "src/core/lib/support/histogram.c", - "src/core/lib/support/host_port.c", - "src/core/lib/support/log.c", - "src/core/lib/support/log_android.c", - "src/core/lib/support/log_linux.c", - "src/core/lib/support/log_posix.c", - "src/core/lib/support/log_windows.c", - "src/core/lib/support/murmur_hash.c", - "src/core/lib/support/slice.c", - "src/core/lib/support/slice_buffer.c", - "src/core/lib/support/stack_lockfree.c", - "src/core/lib/support/string.c", - "src/core/lib/support/string_posix.c", - "src/core/lib/support/string_util_windows.c", - "src/core/lib/support/string_windows.c", - "src/core/lib/support/subprocess_posix.c", - "src/core/lib/support/subprocess_windows.c", - "src/core/lib/support/sync.c", - "src/core/lib/support/sync_posix.c", - "src/core/lib/support/sync_windows.c", - "src/core/lib/support/thd.c", - "src/core/lib/support/thd_posix.c", - "src/core/lib/support/thd_windows.c", - "src/core/lib/support/time.c", - "src/core/lib/support/time_posix.c", - "src/core/lib/support/time_precise.c", - "src/core/lib/support/time_windows.c", - "src/core/lib/support/tls_pthread.c", - "src/core/lib/support/tmpfile_msys.c", - "src/core/lib/support/tmpfile_posix.c", - "src/core/lib/support/tmpfile_windows.c", - "src/core/lib/support/wrap_memcpy.c", - ], - hdrs = [ - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - 
"include/grpc/support/alloc.h", - "include/grpc/support/atm.h", - "include/grpc/support/atm_gcc_atomic.h", - "include/grpc/support/atm_gcc_sync.h", - "include/grpc/support/atm_windows.h", - "include/grpc/support/avl.h", - "include/grpc/support/cmdline.h", - "include/grpc/support/cpu.h", - "include/grpc/support/histogram.h", - "include/grpc/support/host_port.h", - "include/grpc/support/log.h", - "include/grpc/support/log_windows.h", - "include/grpc/support/port_platform.h", - "include/grpc/support/slice.h", - "include/grpc/support/slice_buffer.h", - "include/grpc/support/string_util.h", - "include/grpc/support/subprocess.h", - "include/grpc/support/sync.h", - "include/grpc/support/sync_generic.h", - "include/grpc/support/sync_posix.h", - "include/grpc/support/sync_windows.h", - "include/grpc/support/thd.h", - "include/grpc/support/time.h", - "include/grpc/support/tls.h", - "include/grpc/support/tls_gcc.h", - "include/grpc/support/tls_msvc.h", - "include/grpc/support/tls_pthread.h", - "include/grpc/support/useful.h", - "src/core/lib/profiling/timers.h", - "src/core/lib/support/backoff.h", - "src/core/lib/support/block_annotate.h", - "src/core/lib/support/env.h", - "src/core/lib/support/murmur_hash.h", - "src/core/lib/support/stack_lockfree.h", - "src/core/lib/support/string.h", - "src/core/lib/support/string_windows.h", - "src/core/lib/support/thd_internal.h", - "src/core/lib/support/time_precise.h", - "src/core/lib/support/tmpfile.h", - ], - includes = [ - ".", - "include", - ], - deps = [ - ], -) - -objc_library( - name = "grpc_objc", - srcs = [ - "src/core/ext/census/context.c", - "src/core/ext/census/gen/census.pb.c", - "src/core/ext/census/grpc_context.c", - "src/core/ext/census/grpc_filter.c", - "src/core/ext/census/grpc_plugin.c", - "src/core/ext/census/initialize.c", - "src/core/ext/census/mlog.c", - "src/core/ext/census/operation.c", - "src/core/ext/census/placeholders.c", - "src/core/ext/census/tracing.c", - 
"src/core/ext/client_config/channel_connectivity.c", - "src/core/ext/client_config/client_channel.c", - "src/core/ext/client_config/client_channel_factory.c", - "src/core/ext/client_config/client_config.c", - "src/core/ext/client_config/client_config_plugin.c", - "src/core/ext/client_config/connector.c", - "src/core/ext/client_config/default_initial_connect_string.c", - "src/core/ext/client_config/initial_connect_string.c", - "src/core/ext/client_config/lb_policy.c", - "src/core/ext/client_config/lb_policy_factory.c", - "src/core/ext/client_config/lb_policy_registry.c", - "src/core/ext/client_config/parse_address.c", - "src/core/ext/client_config/resolver.c", - "src/core/ext/client_config/resolver_factory.c", - "src/core/ext/client_config/resolver_registry.c", - "src/core/ext/client_config/subchannel.c", - "src/core/ext/client_config/subchannel_call_holder.c", - "src/core/ext/client_config/subchannel_index.c", - "src/core/ext/client_config/uri_parser.c", - "src/core/ext/lb_policy/grpclb/load_balancer_api.c", - "src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c", - "src/core/ext/lb_policy/pick_first/pick_first.c", - "src/core/ext/lb_policy/round_robin/round_robin.c", - "src/core/ext/load_reporting/load_reporting.c", - "src/core/ext/load_reporting/load_reporting_filter.c", - "src/core/ext/resolver/dns/native/dns_resolver.c", - "src/core/ext/resolver/sockaddr/sockaddr_resolver.c", - "src/core/ext/transport/chttp2/alpn/alpn.c", - "src/core/ext/transport/chttp2/client/insecure/channel_create.c", - "src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c", - "src/core/ext/transport/chttp2/client/secure/secure_channel_create.c", - "src/core/ext/transport/chttp2/server/insecure/server_chttp2.c", - "src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c", - "src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c", - "src/core/ext/transport/chttp2/transport/bin_decoder.c", - 
"src/core/ext/transport/chttp2/transport/bin_encoder.c", - "src/core/ext/transport/chttp2/transport/chttp2_plugin.c", - "src/core/ext/transport/chttp2/transport/chttp2_transport.c", - "src/core/ext/transport/chttp2/transport/frame_data.c", - "src/core/ext/transport/chttp2/transport/frame_goaway.c", - "src/core/ext/transport/chttp2/transport/frame_ping.c", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.c", - "src/core/ext/transport/chttp2/transport/frame_settings.c", - "src/core/ext/transport/chttp2/transport/frame_window_update.c", - "src/core/ext/transport/chttp2/transport/hpack_encoder.c", - "src/core/ext/transport/chttp2/transport/hpack_parser.c", - "src/core/ext/transport/chttp2/transport/hpack_table.c", - "src/core/ext/transport/chttp2/transport/huffsyms.c", - "src/core/ext/transport/chttp2/transport/incoming_metadata.c", - "src/core/ext/transport/chttp2/transport/parsing.c", - "src/core/ext/transport/chttp2/transport/status_conversion.c", - "src/core/ext/transport/chttp2/transport/stream_lists.c", - "src/core/ext/transport/chttp2/transport/stream_map.c", - "src/core/ext/transport/chttp2/transport/timeout_encoding.c", - "src/core/ext/transport/chttp2/transport/varint.c", - "src/core/ext/transport/chttp2/transport/writing.c", - "src/core/lib/channel/channel_args.c", - "src/core/lib/channel/channel_stack.c", - "src/core/lib/channel/channel_stack_builder.c", - "src/core/lib/channel/compress_filter.c", - "src/core/lib/channel/connected_channel.c", - "src/core/lib/channel/http_client_filter.c", - "src/core/lib/channel/http_server_filter.c", - "src/core/lib/compression/compression.c", - "src/core/lib/compression/message_compress.c", - "src/core/lib/debug/trace.c", - "src/core/lib/http/format_request.c", - "src/core/lib/http/httpcli.c", - "src/core/lib/http/httpcli_security_connector.c", - "src/core/lib/http/parser.c", - "src/core/lib/iomgr/closure.c", - "src/core/lib/iomgr/endpoint.c", - "src/core/lib/iomgr/endpoint_pair_posix.c", - 
"src/core/lib/iomgr/endpoint_pair_windows.c", - "src/core/lib/iomgr/error.c", - "src/core/lib/iomgr/ev_epoll_linux.c", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.c", - "src/core/lib/iomgr/ev_poll_posix.c", - "src/core/lib/iomgr/ev_posix.c", - "src/core/lib/iomgr/exec_ctx.c", - "src/core/lib/iomgr/executor.c", - "src/core/lib/iomgr/iocp_windows.c", - "src/core/lib/iomgr/iomgr.c", - "src/core/lib/iomgr/iomgr_posix.c", - "src/core/lib/iomgr/iomgr_windows.c", - "src/core/lib/iomgr/load_file.c", - "src/core/lib/iomgr/network_status_tracker.c", - "src/core/lib/iomgr/polling_entity.c", - "src/core/lib/iomgr/pollset_set_windows.c", - "src/core/lib/iomgr/pollset_windows.c", - "src/core/lib/iomgr/resolve_address_posix.c", - "src/core/lib/iomgr/resolve_address_windows.c", - "src/core/lib/iomgr/sockaddr_utils.c", - "src/core/lib/iomgr/socket_utils_common_posix.c", - "src/core/lib/iomgr/socket_utils_linux.c", - "src/core/lib/iomgr/socket_utils_posix.c", - "src/core/lib/iomgr/socket_windows.c", - "src/core/lib/iomgr/tcp_client_posix.c", - "src/core/lib/iomgr/tcp_client_windows.c", - "src/core/lib/iomgr/tcp_posix.c", - "src/core/lib/iomgr/tcp_server_posix.c", - "src/core/lib/iomgr/tcp_server_windows.c", - "src/core/lib/iomgr/tcp_windows.c", - "src/core/lib/iomgr/time_averaged_stats.c", - "src/core/lib/iomgr/timer.c", - "src/core/lib/iomgr/timer_heap.c", - "src/core/lib/iomgr/udp_server.c", - "src/core/lib/iomgr/unix_sockets_posix.c", - "src/core/lib/iomgr/unix_sockets_posix_noop.c", - "src/core/lib/iomgr/wakeup_fd_eventfd.c", - "src/core/lib/iomgr/wakeup_fd_nospecial.c", - "src/core/lib/iomgr/wakeup_fd_pipe.c", - "src/core/lib/iomgr/wakeup_fd_posix.c", - "src/core/lib/iomgr/workqueue_posix.c", - "src/core/lib/iomgr/workqueue_windows.c", - "src/core/lib/json/json.c", - "src/core/lib/json/json_reader.c", - "src/core/lib/json/json_string.c", - "src/core/lib/json/json_writer.c", - "src/core/lib/security/context/security_context.c", - 
"src/core/lib/security/credentials/composite/composite_credentials.c", - "src/core/lib/security/credentials/credentials.c", - "src/core/lib/security/credentials/credentials_metadata.c", - "src/core/lib/security/credentials/fake/fake_credentials.c", - "src/core/lib/security/credentials/google_default/credentials_posix.c", - "src/core/lib/security/credentials/google_default/credentials_windows.c", - "src/core/lib/security/credentials/google_default/google_default_credentials.c", - "src/core/lib/security/credentials/iam/iam_credentials.c", - "src/core/lib/security/credentials/jwt/json_token.c", - "src/core/lib/security/credentials/jwt/jwt_credentials.c", - "src/core/lib/security/credentials/jwt/jwt_verifier.c", - "src/core/lib/security/credentials/oauth2/oauth2_credentials.c", - "src/core/lib/security/credentials/plugin/plugin_credentials.c", - "src/core/lib/security/credentials/ssl/ssl_credentials.c", - "src/core/lib/security/transport/client_auth_filter.c", - "src/core/lib/security/transport/handshake.c", - "src/core/lib/security/transport/secure_endpoint.c", - "src/core/lib/security/transport/security_connector.c", - "src/core/lib/security/transport/server_auth_filter.c", - "src/core/lib/security/transport/tsi_error.c", - "src/core/lib/security/util/b64.c", - "src/core/lib/security/util/json_util.c", - "src/core/lib/surface/alarm.c", - "src/core/lib/surface/api_trace.c", - "src/core/lib/surface/byte_buffer.c", - "src/core/lib/surface/byte_buffer_reader.c", - "src/core/lib/surface/call.c", - "src/core/lib/surface/call_details.c", - "src/core/lib/surface/call_log_batch.c", - "src/core/lib/surface/channel.c", - "src/core/lib/surface/channel_init.c", - "src/core/lib/surface/channel_ping.c", - "src/core/lib/surface/channel_stack_type.c", - "src/core/lib/surface/completion_queue.c", - "src/core/lib/surface/event_string.c", - "src/core/lib/surface/init.c", - "src/core/lib/surface/init_secure.c", - "src/core/lib/surface/lame_client.c", - 
"src/core/lib/surface/metadata_array.c", - "src/core/lib/surface/server.c", - "src/core/lib/surface/validate_metadata.c", - "src/core/lib/surface/version.c", - "src/core/lib/transport/byte_stream.c", - "src/core/lib/transport/connectivity_state.c", - "src/core/lib/transport/metadata.c", - "src/core/lib/transport/metadata_batch.c", - "src/core/lib/transport/static_metadata.c", - "src/core/lib/transport/transport.c", - "src/core/lib/transport/transport_op_string.c", - "src/core/lib/tsi/fake_transport_security.c", - "src/core/lib/tsi/ssl_transport_security.c", - "src/core/lib/tsi/transport_security.c", - "src/core/plugin_registry/grpc_plugin_registry.c", - ], - hdrs = [ - "include/grpc/byte_buffer.h", - "include/grpc/byte_buffer_reader.h", - "include/grpc/census.h", - "include/grpc/compression.h", - "include/grpc/grpc.h", - "include/grpc/grpc_posix.h", - "include/grpc/grpc_security.h", - "include/grpc/grpc_security_constants.h", - "include/grpc/impl/codegen/alloc.h", - "include/grpc/impl/codegen/atm.h", - "include/grpc/impl/codegen/atm_gcc_atomic.h", - "include/grpc/impl/codegen/atm_gcc_sync.h", - "include/grpc/impl/codegen/atm_windows.h", - "include/grpc/impl/codegen/byte_buffer.h", - "include/grpc/impl/codegen/byte_buffer_reader.h", - "include/grpc/impl/codegen/compression_types.h", - "include/grpc/impl/codegen/connectivity_state.h", - "include/grpc/impl/codegen/grpc_types.h", - "include/grpc/impl/codegen/log.h", - "include/grpc/impl/codegen/port_platform.h", - "include/grpc/impl/codegen/propagation_bits.h", - "include/grpc/impl/codegen/slice.h", - "include/grpc/impl/codegen/slice_buffer.h", - "include/grpc/impl/codegen/status.h", - "include/grpc/impl/codegen/sync.h", - "include/grpc/impl/codegen/sync_generic.h", - "include/grpc/impl/codegen/sync_posix.h", - "include/grpc/impl/codegen/sync_windows.h", - "include/grpc/impl/codegen/time.h", - "include/grpc/status.h", - "src/core/ext/census/aggregation.h", - "src/core/ext/census/census_interface.h", - 
"src/core/ext/census/census_rpc_stats.h", - "src/core/ext/census/gen/census.pb.h", - "src/core/ext/census/grpc_filter.h", - "src/core/ext/census/mlog.h", - "src/core/ext/census/rpc_metric_id.h", - "src/core/ext/client_config/client_channel.h", - "src/core/ext/client_config/client_channel_factory.h", - "src/core/ext/client_config/client_config.h", - "src/core/ext/client_config/connector.h", - "src/core/ext/client_config/initial_connect_string.h", - "src/core/ext/client_config/lb_policy.h", - "src/core/ext/client_config/lb_policy_factory.h", - "src/core/ext/client_config/lb_policy_registry.h", - "src/core/ext/client_config/parse_address.h", - "src/core/ext/client_config/resolver.h", - "src/core/ext/client_config/resolver_factory.h", - "src/core/ext/client_config/resolver_registry.h", - "src/core/ext/client_config/subchannel.h", - "src/core/ext/client_config/subchannel_call_holder.h", - "src/core/ext/client_config/subchannel_index.h", - "src/core/ext/client_config/uri_parser.h", - "src/core/ext/lb_policy/grpclb/load_balancer_api.h", - "src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.h", - "src/core/ext/load_reporting/load_reporting.h", - "src/core/ext/load_reporting/load_reporting_filter.h", - "src/core/ext/transport/chttp2/alpn/alpn.h", - "src/core/ext/transport/chttp2/transport/bin_decoder.h", - "src/core/ext/transport/chttp2/transport/bin_encoder.h", - "src/core/ext/transport/chttp2/transport/chttp2_transport.h", - "src/core/ext/transport/chttp2/transport/frame.h", - "src/core/ext/transport/chttp2/transport/frame_data.h", - "src/core/ext/transport/chttp2/transport/frame_goaway.h", - "src/core/ext/transport/chttp2/transport/frame_ping.h", - "src/core/ext/transport/chttp2/transport/frame_rst_stream.h", - "src/core/ext/transport/chttp2/transport/frame_settings.h", - "src/core/ext/transport/chttp2/transport/frame_window_update.h", - "src/core/ext/transport/chttp2/transport/hpack_encoder.h", - "src/core/ext/transport/chttp2/transport/hpack_parser.h", - 
"src/core/ext/transport/chttp2/transport/hpack_table.h", - "src/core/ext/transport/chttp2/transport/http2_errors.h", - "src/core/ext/transport/chttp2/transport/huffsyms.h", - "src/core/ext/transport/chttp2/transport/incoming_metadata.h", - "src/core/ext/transport/chttp2/transport/internal.h", - "src/core/ext/transport/chttp2/transport/status_conversion.h", - "src/core/ext/transport/chttp2/transport/stream_map.h", - "src/core/ext/transport/chttp2/transport/timeout_encoding.h", - "src/core/ext/transport/chttp2/transport/varint.h", - "src/core/lib/channel/channel_args.h", - "src/core/lib/channel/channel_stack.h", - "src/core/lib/channel/channel_stack_builder.h", - "src/core/lib/channel/compress_filter.h", - "src/core/lib/channel/connected_channel.h", - "src/core/lib/channel/context.h", - "src/core/lib/channel/http_client_filter.h", - "src/core/lib/channel/http_server_filter.h", - "src/core/lib/compression/algorithm_metadata.h", - "src/core/lib/compression/message_compress.h", - "src/core/lib/debug/trace.h", - "src/core/lib/http/format_request.h", - "src/core/lib/http/httpcli.h", - "src/core/lib/http/parser.h", - "src/core/lib/iomgr/closure.h", - "src/core/lib/iomgr/endpoint.h", - "src/core/lib/iomgr/endpoint_pair.h", - "src/core/lib/iomgr/error.h", - "src/core/lib/iomgr/ev_epoll_linux.h", - "src/core/lib/iomgr/ev_poll_and_epoll_posix.h", - "src/core/lib/iomgr/ev_poll_posix.h", - "src/core/lib/iomgr/ev_posix.h", - "src/core/lib/iomgr/exec_ctx.h", - "src/core/lib/iomgr/executor.h", - "src/core/lib/iomgr/iocp_windows.h", - "src/core/lib/iomgr/iomgr.h", - "src/core/lib/iomgr/iomgr_internal.h", - "src/core/lib/iomgr/iomgr_posix.h", - "src/core/lib/iomgr/load_file.h", - "src/core/lib/iomgr/network_status_tracker.h", - "src/core/lib/iomgr/polling_entity.h", - "src/core/lib/iomgr/pollset.h", - "src/core/lib/iomgr/pollset_set.h", - "src/core/lib/iomgr/pollset_set_windows.h", - "src/core/lib/iomgr/pollset_windows.h", - "src/core/lib/iomgr/resolve_address.h", - 
"src/core/lib/iomgr/sockaddr.h", - "src/core/lib/iomgr/sockaddr_posix.h", - "src/core/lib/iomgr/sockaddr_utils.h", - "src/core/lib/iomgr/sockaddr_windows.h", - "src/core/lib/iomgr/socket_utils_posix.h", - "src/core/lib/iomgr/socket_windows.h", - "src/core/lib/iomgr/tcp_client.h", - "src/core/lib/iomgr/tcp_posix.h", - "src/core/lib/iomgr/tcp_server.h", - "src/core/lib/iomgr/tcp_windows.h", - "src/core/lib/iomgr/time_averaged_stats.h", - "src/core/lib/iomgr/timer.h", - "src/core/lib/iomgr/timer_heap.h", - "src/core/lib/iomgr/udp_server.h", - "src/core/lib/iomgr/unix_sockets_posix.h", - "src/core/lib/iomgr/wakeup_fd_pipe.h", - "src/core/lib/iomgr/wakeup_fd_posix.h", - "src/core/lib/iomgr/workqueue.h", - "src/core/lib/iomgr/workqueue_posix.h", - "src/core/lib/iomgr/workqueue_windows.h", - "src/core/lib/json/json.h", - "src/core/lib/json/json_common.h", - "src/core/lib/json/json_reader.h", - "src/core/lib/json/json_writer.h", - "src/core/lib/security/context/security_context.h", - "src/core/lib/security/credentials/composite/composite_credentials.h", - "src/core/lib/security/credentials/credentials.h", - "src/core/lib/security/credentials/fake/fake_credentials.h", - "src/core/lib/security/credentials/google_default/google_default_credentials.h", - "src/core/lib/security/credentials/iam/iam_credentials.h", - "src/core/lib/security/credentials/jwt/json_token.h", - "src/core/lib/security/credentials/jwt/jwt_credentials.h", - "src/core/lib/security/credentials/jwt/jwt_verifier.h", - "src/core/lib/security/credentials/oauth2/oauth2_credentials.h", - "src/core/lib/security/credentials/plugin/plugin_credentials.h", - "src/core/lib/security/credentials/ssl/ssl_credentials.h", - "src/core/lib/security/transport/auth_filters.h", - "src/core/lib/security/transport/handshake.h", - "src/core/lib/security/transport/secure_endpoint.h", - "src/core/lib/security/transport/security_connector.h", - "src/core/lib/security/transport/tsi_error.h", - "src/core/lib/security/util/b64.h", - 
"src/core/lib/security/util/json_util.h", - "src/core/lib/surface/api_trace.h", - "src/core/lib/surface/call.h", - "src/core/lib/surface/call_test_only.h", - "src/core/lib/surface/channel.h", - "src/core/lib/surface/channel_init.h", - "src/core/lib/surface/channel_stack_type.h", - "src/core/lib/surface/completion_queue.h", - "src/core/lib/surface/event_string.h", - "src/core/lib/surface/init.h", - "src/core/lib/surface/lame_client.h", - "src/core/lib/surface/server.h", - "src/core/lib/transport/byte_stream.h", - "src/core/lib/transport/connectivity_state.h", - "src/core/lib/transport/metadata.h", - "src/core/lib/transport/metadata_batch.h", - "src/core/lib/transport/static_metadata.h", - "src/core/lib/transport/transport.h", - "src/core/lib/transport/transport_impl.h", - "src/core/lib/tsi/fake_transport_security.h", - "src/core/lib/tsi/ssl_transport_security.h", - "src/core/lib/tsi/ssl_types.h", - "src/core/lib/tsi/transport_security.h", - "src/core/lib/tsi/transport_security_interface.h", - "third_party/nanopb/pb.h", - "third_party/nanopb/pb_decode.h", - "third_party/nanopb/pb_encode.h", - ], - includes = [ - ".", - "include", - ], - sdk_dylibs = ["libz"], - deps = [ - ":gpr_objc", - "//external:libssl_objc", - "//external:nanopb", - ], -) - -cc_binary( - name = "grpc_cpp_plugin", - srcs = [ - "src/compiler/cpp_plugin.cc", - ], - deps = [ - ":grpc_plugin_support", - "//external:protobuf_compiler", - ], -) - -cc_binary( - name = "grpc_csharp_plugin", - srcs = [ - "src/compiler/csharp_plugin.cc", - ], - deps = [ - ":grpc_plugin_support", - "//external:protobuf_compiler", - ], -) - -cc_binary( - name = "grpc_node_plugin", - srcs = [ - "src/compiler/node_plugin.cc", - ], - deps = [ - ":grpc_plugin_support", - "//external:protobuf_compiler", - ], -) - -cc_binary( - name = "grpc_objective_c_plugin", - srcs = [ - "src/compiler/objective_c_plugin.cc", - ], - deps = [ - ":grpc_plugin_support", - "//external:protobuf_compiler", - ], -) - -cc_binary( - name = 
"grpc_python_plugin", - srcs = [ - "src/compiler/python_plugin.cc", - ], - deps = [ - ":grpc_plugin_support", - "//external:protobuf_compiler", - ], -) - -cc_binary( - name = "grpc_ruby_plugin", - srcs = [ - "src/compiler/ruby_plugin.cc", - ], - deps = [ - ":grpc_plugin_support", - "//external:protobuf_compiler", - ], -) - -objc_path = "src/objective-c" - -rx_library_path = objc_path + "/RxLibrary" - -objc_library( - name = "rx_library", - srcs = glob([ - rx_library_path + "/*.m", - rx_library_path + "/transformations/*.m", - ]), - hdrs = glob([ - rx_library_path + "/*.h", - rx_library_path + "/transformations/*.h", - ]), - includes = [objc_path], - deps = [ - ":rx_library_private", - ], -) - -objc_library( - name = "rx_library_private", - srcs = glob([rx_library_path + "/private/*.m"]), - hdrs = glob([rx_library_path + "/private/*.h"]), - visibility = ["//visibility:private"], -) - -objc_client_path = objc_path + "/GRPCClient" - -objc_library( - name = "grpc_client", - srcs = glob([ - objc_client_path + "/*.m", - objc_client_path + "/private/*.m", - ]), - hdrs = glob([ - objc_client_path + "/*.h", - objc_client_path + "/private/*.h", - ]), - bundles = [":gRPCCertificates"], - includes = [objc_path], - deps = [ - ":grpc_objc", - ":rx_library", - ], -) - -objc_bundle_library( - # The choice of name is signicant here, since it determines the bundle name. 
- name = "gRPCCertificates", - resources = ["etc/roots.pem"], -) - -proto_objc_rpc_path = objc_path + "/ProtoRPC" - -objc_library( - name = "proto_objc_rpc", - srcs = glob([ - proto_objc_rpc_path + "/*.m", - ]), - hdrs = glob([ - proto_objc_rpc_path + "/*.h", - ]), - includes = [objc_path], - deps = [ - ":grpc_client", - ":rx_library", - "//external:protobuf_objc", - ], -) diff --git a/tensorflow/tensorboard/scripts/__init__.py b/third_party/grpc/BUILD similarity index 100% rename from tensorflow/tensorboard/scripts/__init__.py rename to third_party/grpc/BUILD diff --git a/third_party/grpc/grpc.patch b/third_party/grpc/grpc.patch new file mode 100644 index 0000000000000000000000000000000000000000..6e5b4b02fba2c4c98c82a1366f090dc985bbcda0 --- /dev/null +++ b/third_party/grpc/grpc.patch @@ -0,0 +1,76 @@ +diff --git a/bazel/grpc_build_system.bzl b/bazel/grpc_build_system.bzl +index f793cae56d..0295adb8ab 100644 +--- a/bazel/grpc_build_system.bzl ++++ b/bazel/grpc_build_system.bzl +@@ -80,7 +80,7 @@ def grpc_cc_test(name, srcs = [], deps = [], external_deps = [], args = [], data + linkopts = ["-pthread"], + ) + +-def grpc_cc_binary(name, srcs = [], deps = [], external_deps = [], args = [], data = [], language = "C++", testonly = False, linkshared = False): ++def grpc_cc_binary(name, srcs = [], deps = [], external_deps = [], args = [], data = [], language = "C++", testonly = False, linkshared = False, linkopts = []): + copts = [] + if language.upper() == "C": + copts = ["-std=c99"] +@@ -93,7 +93,7 @@ def grpc_cc_binary(name, srcs = [], deps = [], external_deps = [], args = [], da + linkshared = linkshared, + deps = deps + ["//external:" + dep for dep in external_deps], + copts = copts, +- linkopts = ["-pthread"], ++ linkopts = ["-pthread"] + linkopts, + ) + + def grpc_generate_one_off_targets(): +diff --git a/src/core/plugin_registry/grpc_unsecure_plugin_registry.c b/src/core/plugin_registry/grpc_unsecure_plugin_registry.c +index 7eb599d81a..4cc2e30af4 100644 +--- 
a/src/core/plugin_registry/grpc_unsecure_plugin_registry.c ++++ b/src/core/plugin_registry/grpc_unsecure_plugin_registry.c +@@ -28,18 +28,12 @@ extern void grpc_client_channel_init(void); + extern void grpc_client_channel_shutdown(void); + extern void grpc_inproc_plugin_init(void); + extern void grpc_inproc_plugin_shutdown(void); +-extern void grpc_resolver_dns_ares_init(void); +-extern void grpc_resolver_dns_ares_shutdown(void); + extern void grpc_resolver_dns_native_init(void); + extern void grpc_resolver_dns_native_shutdown(void); + extern void grpc_resolver_sockaddr_init(void); + extern void grpc_resolver_sockaddr_shutdown(void); +-extern void grpc_resolver_fake_init(void); +-extern void grpc_resolver_fake_shutdown(void); + extern void grpc_load_reporting_plugin_init(void); + extern void grpc_load_reporting_plugin_shutdown(void); +-extern void grpc_lb_policy_grpclb_init(void); +-extern void grpc_lb_policy_grpclb_shutdown(void); + extern void grpc_lb_policy_pick_first_init(void); + extern void grpc_lb_policy_pick_first_shutdown(void); + extern void grpc_lb_policy_round_robin_init(void); +@@ -64,18 +58,12 @@ void grpc_register_built_in_plugins(void) { + grpc_client_channel_shutdown); + grpc_register_plugin(grpc_inproc_plugin_init, + grpc_inproc_plugin_shutdown); +- grpc_register_plugin(grpc_resolver_dns_ares_init, +- grpc_resolver_dns_ares_shutdown); + grpc_register_plugin(grpc_resolver_dns_native_init, + grpc_resolver_dns_native_shutdown); + grpc_register_plugin(grpc_resolver_sockaddr_init, + grpc_resolver_sockaddr_shutdown); +- grpc_register_plugin(grpc_resolver_fake_init, +- grpc_resolver_fake_shutdown); + grpc_register_plugin(grpc_load_reporting_plugin_init, + grpc_load_reporting_plugin_shutdown); +- grpc_register_plugin(grpc_lb_policy_grpclb_init, +- grpc_lb_policy_grpclb_shutdown); + grpc_register_plugin(grpc_lb_policy_pick_first_init, + grpc_lb_policy_pick_first_shutdown); + grpc_register_plugin(grpc_lb_policy_round_robin_init, +diff --git 
a/test/cpp/util/BUILD b/test/cpp/util/BUILD +index 33240f6f69..d2e1f67f06 100644 +--- a/test/cpp/util/BUILD ++++ b/test/cpp/util/BUILD +@@ -29,6 +29,7 @@ package( + grpc_cc_binary( + name = "testso.so", + srcs = [], ++ linkopts = ['-Wl,--no-undefined'], + linkshared = 1, + deps = ["//:grpc++_unsecure"], + ) diff --git a/third_party/html5lib.BUILD b/third_party/html5lib.BUILD deleted file mode 100644 index 63aac14f1559a86f626a5d99db973111f86f92ae..0000000000000000000000000000000000000000 --- a/third_party/html5lib.BUILD +++ /dev/null @@ -1,17 +0,0 @@ -# Description: -# Import of html5lib library. - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # BSD-like notice-style license, see LICENSE file - -exports_files(["LICENSE"]) - -py_library( - name = "org_html5lib", - srcs = glob(["html5lib/**/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "@six_archive//:six", - ], -) diff --git a/third_party/js.bzl b/third_party/js.bzl deleted file mode 100644 index 2d2339c95e5b537ae9ba0ebe8044808ebe411a36..0000000000000000000000000000000000000000 --- a/third_party/js.bzl +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# TensorBoard external JS dependencies (both infrastructure and frontend libs) -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") - - - ############################################################################## - # TensorBoard Build Tools -def tensorboard_js_workspace(): - filegroup_external( - name = "org_nodejs", - # MIT with portions licensed: - # - MIT - # - Old MIT - # - 2-Clause-BSD - # - 3-Clause-BSD - # - ISC - # - Unicode - # - zlib - # - Artistic 2.0 - licenses = ["notice"], - sha256_urls_extract_macos = { - "47109a00cac344d80296c195451bb5eee7c21727fcef1594384ddfe1f852957a": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/node-v4.3.2-darwin-x64.tar.xz", - "http://nodejs.org/dist/v4.3.2/node-v4.3.2-darwin-x64.tar.xz", - ], - }, - sha256_urls_windows = { - "3d4cfca9dcec556a077a2324bf5bd165ea3e6e64a2bfd7fc6e7a1f0dc4eb552b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/nodejs/node/v4.3.2/LICENSE", - "https://raw.githubusercontent.com/nodejs/node/v4.3.2/LICENSE", - ], - "606c44c42d17866c017c50c0afadad411d9492ac4281d2431b937f881911614e": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/win-x64/node.exe", - "http://nodejs.org/dist/v4.3.2/win-x64/node.exe", - ], - "451a40570099a95488d6438f175813629e0430f87f23c8659bc18dc42494820a": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/win-x64/node.lib", - "http://nodejs.org/dist/v4.3.2/win-x64/node.lib", - ], - }, - sha256_urls_extract = { - "4350d0431b49697517c6cca5d66adf5f74eb9101c52f52ae959fa94225822d44": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/node-v4.3.2-linux-x64.tar.xz", - "http://nodejs.org/dist/v4.3.2/node-v4.3.2-linux-x64.tar.xz", - ], - }, - strip_prefix = { - "node-v4.3.2-darwin-x64.tar.xz": "node-v4.3.2-darwin-x64", - "node-v4.3.2-linux-x64.tar.xz": "node-v4.3.2-linux-x64", - }, - executable = [ - "node", - "node.exe", - ], - ) - - filegroup_external( - name = 
"com_microsoft_typescript", - licenses = ["notice"], # Apache 2.0 - sha256_urls = { - "a7d00bfd54525bc694b6e32f64c7ebcf5e6b7ae3657be5cc12767bce74654a47": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - ], - "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - ], - "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - ], - }, - extra_build_file_content = "\n".join([ - "sh_binary(", - " name = \"tsc\",", - " srcs = [\"tsc.sh\"],", - " data = [", - " \"tsc.js\",", - " \"@org_nodejs\",", - " ],", - ")", - "", - "genrule(", - " name = \"tsc_sh\",", - " outs = [\"tsc.sh\"],", - " cmd = \"cat >$@ <<'EOF'\\n\" +", - " \"#!/bin/bash\\n\" +", - " \"NODE=external/org_nodejs/bin/node\\n\" +", - " \"if [[ -e external/org_nodejs/node.exe ]]; then\\n\" +", - " \" NODE=external/org_nodejs/node.exe\\n\" +", - " \"fi\\n\" +", - " \"exec $${NODE} external/com_microsoft_typescript/tsc.js \\\"$$@\\\"\\n\" +", - " \"EOF\",", - " executable = True,", - ")", - ]), - ) - - - native.new_http_archive( - name = "io_angular_clutz", - build_file = "//third_party:clutz.BUILD", - sha256 = "2981de41d1ff4774b544423da9a2cd8beb3be649e95aef2ef2fd83957300b3fe", - strip_prefix = "clutz-b0db5ade9bb535d387f05292316c422790c9848e", - urls = [ - "http://mirror.bazel.build/github.com/angular/clutz/archive/b0db5ade9bb535d387f05292316c422790c9848e.tar.gz", # 2017-05-22 - "https://github.com/angular/clutz/archive/b0db5ade9bb535d387f05292316c422790c9848e.tar.gz", - ], - ) - - 
filegroup_external( - name = "com_google_javascript_closure_compiler_externs", - licenses = ["notice"], # Apache 2.0 - sha256_urls_extract = { - "0f515a6ebfa138490b3c5ea9f3591ea1a7e4a930d3074f18b3eca86084ad9b66": [ - "http://mirror.bazel.build/github.com/google/closure-compiler/archive/b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz", # 2017-06-02 - "https://github.com/google/closure-compiler/archive/b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz", - ], - }, - strip_prefix = {"b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz": "closure-compiler-b37e6000001b0a6bf4c0be49024ebda14a8711d9/externs"}, - ) - - filegroup_external( - name = "com_google_javascript_closure_compiler_externs_polymer", - licenses = ["notice"], # Apache 2.0 - sha256_urls = { - "23baad9a200a717a821c6df504c84d3a893d7ea9102b14876eb80097e3b94292": [ - "http://mirror.bazel.build/raw.githubusercontent.com/google/closure-compiler/0e8dc5597a295ee259e3fecd98d6535dc621232f/contrib/externs/polymer-1.0.js", # 2017-05-27 - "https://raw.githubusercontent.com/google/closure-compiler/0e8dc5597a295ee259e3fecd98d6535dc621232f/contrib/externs/polymer-1.0.js", - ], - }, - ) - - filegroup_external( - name = "org_threejs", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "7aff264bd84c90bed3c72a4dc31db8c19151853c6df6980f52b01d3e9872c82d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/build/three.js", - "https://raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/build/three.js", - ], - "0e98ded15bb7fe398a655667e76b39909d36c0973a8950d01c62f65f93161c27": [ - "http://mirror.bazel.build/raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/examples/js/controls/OrbitControls.js", - "https://raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/examples/js/controls/OrbitControls.js", - ], - }, - ) - - 
############################################################################## - # TensorBoard JavaScript Production Dependencies - web_library_external( - name = "com_lodash", - licenses = ["notice"], # MIT - sha256 = "0e88207e5f90af4ce8790d6e1e7d09d2702d81bce0bafdc253d18c0a5bf7661e", - urls = [ - "http://mirror.bazel.build/github.com/lodash/lodash/archive/3.10.1.tar.gz", - "https://github.com/lodash/lodash/archive/3.10.1.tar.gz", - ], - strip_prefix = "lodash-3.10.1", - path = "/lodash", - srcs = ["lodash.js"], - ) - - filegroup_external( - name = "com_numericjs", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "0e94aada97f12dee6118064add9170484c55022f5d53206ee4407143cd36ddcd": [ - "http://mirror.bazel.build/raw.githubusercontent.com/sloisel/numeric/v1.2.6/license.txt", - "https://raw.githubusercontent.com/sloisel/numeric/v1.2.6/license.txt", - ], - "dfaca3b8485bee735788cc6eebca82ea25719adc1fb8911c7799c6bd5a95df3b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/sloisel/numeric/v1.2.6/src/numeric.js", - "https://raw.githubusercontent.com/sloisel/numeric/v1.2.6/src/numeric.js", - ], - }, - ) - - filegroup_external( - name = "com_palantir_plottable", - # no @license header - licenses = ["notice"], # MIT - sha256_urls_extract = { - # Plottable doesn't have a release tarball on GitHub. Using the - # sources directly from git also requires running Node tooling - # beforehand to generate files. NPM is the only place to get it. 
- "e3159beb279391c47433789f22b32bac88488cfcad6c0b6ec8605ce6b0081b0d": [ - "http://mirror.bazel.build/registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", - "https://registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", - ], - }, - ) - - filegroup_external( - name = "io_github_cpettitt_dagre", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "6a349742a6cb219d5a2fc8d0844f6d89a6efc62e20c664450d884fc7ff2d6015": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/dagre/v0.7.4/LICENSE", - "https://raw.githubusercontent.com/cpettitt/dagre/v0.7.4/LICENSE", - ], - "7323829ddd77924a69e2b1235ded3eac30acd990da0f037e0fbd3c8e9035b50d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/dagre/v0.7.4/dist/dagre.core.js", - "https://raw.githubusercontent.com/cpettitt/dagre/v0.7.4/dist/dagre.core.js", - ], - }, - ) - - filegroup_external( - name = "io_github_cpettitt_graphlib", - licenses = ["notice"], # MIT - sha256_urls = { - "6a349742a6cb219d5a2fc8d0844f6d89a6efc62e20c664450d884fc7ff2d6015": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/LICENSE", - "https://raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/LICENSE", - ], - "772045d412b1513b549be991c2e1846c38019429d43974efcae943fbe83489bf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/dist/graphlib.core.js", - "https://raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/dist/graphlib.core.js", - ], - }, - ) - - filegroup_external( - name = "io_github_waylonflinn_weblas", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "633f2861a9a862b9cd7967e841e14dd3527912f209d6563595774fa31e3d84cb": [ - "http://mirror.bazel.build/raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/LICENSES", - "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/LICENSE", - ], - "f138fce57f673ca8a633f4aee5ae5b6fcb6ad0de59069a42a74e996fd04d8fcc": [ - 
"http://mirror.bazel.build/raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", - "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", - ], - }, - ) - - filegroup_external( - name = "org_d3js", - # no @license header - licenses = ["notice"], # BSD-3-Clause - sha256_urls_extract = { - "b5fac5b296bc196e6aa7b59f9e33986fc44d23d59a0e211705187be9e35b943d": [ - "http://mirror.bazel.build/github.com/d3/d3/releases/download/v4.8.0/d3.zip", - "https://github.com/d3/d3/releases/download/v4.8.0/d3.zip", - ], - }, - # TODO(jart): Use srcs=["d3.js"] instead of this once supported. - generated_rule_name = "all_files", - extra_build_file_content = "\n".join([ - "filegroup(", - " name = \"org_d3js\",", - " srcs = [\"d3.js\"],", - ")", - ]), - ) - - filegroup_external( - name = "org_chromium_catapult_vulcanized_trace_viewer", - licenses = ["notice"], # BSD-3-Clause - sha256_urls = { - "f0df289ba9d03d857ad1c2f5918861376b1510b71588ffc60eff5c7a7bfedb09": [ - "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", - "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", - ], - "9e99e79439ea5a1471bd4dd325bd6733e133bcb3da4df4b878ed6d2aec7c8d86": [ - "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html", - "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html" - ], - }, - ) - - ############################################################################## - # TensorBoard Testing Dependencies - web_library_external( - name = "org_npmjs_registry_accessibility_developer_tools", - licenses = ["notice"], # Apache License 2.0 - sha256 = "1d6a72f401c9d53f68238c617dd43a05cd85ca5aa2e676a5b3c352711448e093", - urls = [ - 
"http://mirror.bazel.build/registry.npmjs.org/accessibility-developer-tools/-/accessibility-developer-tools-2.10.0.tgz", - "https://registry.npmjs.org/accessibility-developer-tools/-/accessibility-developer-tools-2.10.0.tgz", - ], - strip_prefix = "package", - path = "/accessibility-developer-tools", - suppress = ["strictDependencies"], - ) - - web_library_external( - name = "org_npmjs_registry_async", - licenses = ["notice"], # MIT - sha256 = "08655255ae810bf4d1cb1642df57658fcce823776d3ba8f4b46f4bbff6c87ece", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/async/-/async-1.5.0.tgz", - "https://registry.npmjs.org/async/-/async-1.5.0.tgz", - ], - strip_prefix = "package", - path = "/async", - ) - - web_library_external( - name = "org_npmjs_registry_chai", - licenses = ["notice"], # MIT - sha256 = "aca8137bed5bb295bd7173325b7ad604cd2aeb341d739232b4f9f0b26745be90", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/chai/-/chai-3.5.0.tgz", - "https://registry.npmjs.org/chai/-/chai-3.5.0.tgz", - ], - strip_prefix = "package", - path = "/chai", - ) - - web_library_external( - name = "org_npmjs_registry_mocha", - licenses = ["notice"], # MIT - sha256 = "13ef37a071196a2fba680799b906555d3f0ab61e80a7e8f73f93e77914590dd4", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/mocha/-/mocha-2.5.3.tgz", - "https://registry.npmjs.org/mocha/-/mocha-2.5.3.tgz", - ], - suppress = ["strictDependencies"], - strip_prefix = "package", - path = "/mocha", - ) - - web_library_external( - name = "org_npmjs_registry_sinon", - licenses = ["notice"], # BSD-3-Clause - sha256 = "49edb057695fc9019aae992bf7e677a07de7c6ce2bf9f9facde4a245045d1532", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/sinon/-/sinon-1.17.4.tgz", - "https://registry.npmjs.org/sinon/-/sinon-1.17.4.tgz", - ], - strip_prefix = "package/lib", - path = "/sinonjs", - ) - - web_library_external( - name = "org_npmjs_registry_sinon_chai", - licenses = ["notice"], # BSD-3-Clause - sha256 = 
"b85fc56f713832960b56fe9269ee4bb2cd41edd2ceb130b0936e5bdbed5dea63", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/sinon-chai/-/sinon-chai-2.8.0.tgz", - "https://registry.npmjs.org/sinon-chai/-/sinon-chai-2.8.0.tgz", - ], - strip_prefix = "package", - path = "/sinon-chai", - ) - - web_library_external( - name = "org_npmjs_registry_stacky", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c659e60f7957d9d80c23a7aacc4d71b19c6421a08f91174c0062de369595acae", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/stacky/-/stacky-1.3.1.tgz", - "https://registry.npmjs.org/stacky/-/stacky-1.3.1.tgz", - ], - strip_prefix = "package", - path = "/stacky", - ) - - web_library_external( - name = "org_npmjs_registry_web_component_tester", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9d4ebd4945df8a936916d4d32b7f280f2a3afa35f79e7ca8ad3ed0a42770c537", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/web-component-tester/-/web-component-tester-4.3.6.tgz", - "https://registry.npmjs.org/web-component-tester/-/web-component-tester-4.3.6.tgz", - ], - strip_prefix = "package", - path = "/web-component-tester", - suppress = [ - "absolutePaths", - "strictDependencies", - ], - deps = [ - "@com_lodash", - "@org_npmjs_registry_accessibility_developer_tools", - "@org_npmjs_registry_async", - "@org_npmjs_registry_chai", - "@org_npmjs_registry_mocha", - "@org_npmjs_registry_sinon", - "@org_npmjs_registry_sinon_chai", - "@org_npmjs_registry_stacky", - "@org_polymer_test_fixture", - ], - ) - - web_library_external( - name = "org_polymer_test_fixture", - licenses = ["notice"], # BSD-3-Clause - sha256 = "59d6cfb1187733b71275becfea181fe0aa1f734df5ff77f5850c806bbbf9a0d9", - strip_prefix = "test-fixture-2.0.1", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/test-fixture/archive/v2.0.1.tar.gz", - "https://github.com/PolymerElements/test-fixture/archive/v2.0.1.tar.gz", - ], - path = "/test-fixture", - exclude = ["test/**"], - ) - diff --git 
a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 32266997a7e54c09525a60a48d2ad330941e2668..2d96406d27047ab9d518a79584cfae8b43c9feb4 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -257,6 +257,16 @@ cc_library( includes = ["include"], ) +# A creator of an empty file include/llvm/Support/VCSRevision.h. +# This is usually populated by the upstream build infrastructure, but in this +# case we leave it blank. See upstream revision r300160. +genrule( + name = "vcs_revision_gen", + srcs = [], + outs = ["include/llvm/Support/VCSRevision.h"], + cmd = "echo '' > \"$@\"", +) + # Rules that apply the LLVM tblgen tool. gentbl( name = "intrinsics_gen", @@ -376,6 +386,7 @@ llvm_target_list = [ "tbl_outs": [ ("-gen-register-bank", "lib/Target/ARM/ARMGenRegisterBank.inc"), ("-gen-register-info", "lib/Target/ARM/ARMGenRegisterInfo.inc"), + ("-gen-searchable-tables", "lib/Target/ARM/ARMGenSystemRegister.inc"), ("-gen-instr-info", "lib/Target/ARM/ARMGenInstrInfo.inc"), ("-gen-emitter", "lib/Target/ARM/ARMGenMCCodeEmitter.inc"), ("-gen-pseudo-lowering", "lib/Target/ARM/ARMGenMCPseudoLowering.inc"), @@ -453,6 +464,7 @@ llvm_target_list = [ "include/llvm/IR/Intrinsics*.td", "include/llvm/TableGen/*.td", "include/llvm/Target/*.td", + "include/llvm/Target/GlobalISel/*.td", ]), ) for target in llvm_target_list @@ -868,6 +880,7 @@ cc_library( deps = [ ":arm_desc", ":arm_info", + ":arm_utils", ":config", ":mc", ":mc_parser", @@ -886,12 +899,14 @@ cc_library( "include/llvm/Target/ARM/InstPrinter/*.h", "include/llvm/Target/ARM/InstPrinter/*.def", "include/llvm/Target/ARM/InstPrinter/*.inc", + "lib/Target/ARM/*.h", "lib/Target/ARM/InstPrinter/*.h", ]), copts = ["-Iexternal/llvm/lib/Target/ARM"], deps = [ ":arm_info", ":arm_target_gen", + ":arm_utils", ":config", ":mc", ":support", @@ -917,6 +932,7 @@ cc_library( ":arm_asm_printer", ":arm_desc", ":arm_info", + ":arm_utils", ":asm_printer", ":code_gen", ":config", @@ -1005,6 +1021,29 @@ cc_library( ], 
) +cc_library( + name = "arm_utils", + srcs = glob([ + "lib/Target/ARM/Utils/*.c", + "lib/Target/ARM/Utils/*.cpp", + "lib/Target/ARM/Utils/*.inc", + "lib/Target/ARM/MCTargetDesc/*.h", + ]), + hdrs = glob([ + "include/llvm/Target/ARM/Utils/*.h", + "include/llvm/Target/ARM/Utils/*.def", + "include/llvm/Target/ARM/Utils/*.inc", + "lib/Target/ARM/Utils/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/ARM"], + deps = [ + ":arm_target_gen", + ":config", + ":mc", + ":support", + ], +) + cc_library( name = "asm_parser", srcs = glob([ @@ -1067,6 +1106,8 @@ cc_library( "include/llvm/BinaryFormat/*.h", "include/llvm/BinaryFormat/*.def", "include/llvm/BinaryFormat/*.inc", + "include/llvm/BinaryFormat/ELFRelocs/*.def", + "include/llvm/BinaryFormat/WasmRelocs/*.def", ]), deps = [ ":config", @@ -1116,6 +1157,7 @@ cc_library( ":config", ":core", ":mc", + ":object", ":support", ], ) @@ -1165,6 +1207,7 @@ cc_library( "lib/IR/*.h", ]), hdrs = glob([ + "include/llvm/Analysis/*.def", "include/llvm/IR/*.h", "include/llvm/IR/*.def", "include/llvm/IR/*.inc", @@ -1194,6 +1237,7 @@ cc_library( "include/llvm/DebugInfo/CodeView/*.inc", ]), deps = [ + ":binary_format", ":config", ":debug_info_msf", ":support", @@ -1426,6 +1470,7 @@ cc_library( "include/llvm/MC/*.inc", ]), deps = [ + ":binary_format", ":config", ":debug_info_code_view", ":support", @@ -1921,6 +1966,8 @@ cc_library( "lib/Support/Unix/*.h", "include/llvm-c/*.h", "include/llvm/CodeGen/MachineValueType.h", + "include/llvm/BinaryFormat/COFF.h", + "include/llvm/BinaryFormat/MachO.h", "lib/Support/*.h", ]), hdrs = glob([ @@ -1931,7 +1978,9 @@ cc_library( "include/llvm/Support/ELFRelocs/*.def", "include/llvm/Support/WasmRelocs/*.def", ]) + [ + "include/llvm/BinaryFormat/MachO.def", "include/llvm/Support/DataTypes.h", + "include/llvm/Support/VCSRevision.h", "include/llvm/ExecutionEngine/ObjectMemoryBuffer.h", ], deps = [ @@ -1975,6 +2024,8 @@ cc_library( "lib/Target/*.h", ]), hdrs = glob([ + "include/llvm/CodeGen/*.h", + 
"include/llvm/CodeGen/*.def", "include/llvm/Target/*.h", "include/llvm/Target/*.def", "include/llvm/Target/*.inc", diff --git a/third_party/lmdb.BUILD b/third_party/lmdb.BUILD index 7c6e3dc3f0531f7e2dc3c4ad782a6a02a6b4e514..9b3e1d97c83b44bba97e5513ae41c1511cf33ce7 100644 --- a/third_party/lmdb.BUILD +++ b/third_party/lmdb.BUILD @@ -19,8 +19,8 @@ cc_library( "-w", ], linkopts = select({ - ":windows": ["-Wl,advapi32.lib"], # InitializeSecurityDescriptor, SetSecurityDescriptorDacl - ":windows_msvc": ["-Wl,advapi32.lib"], + ":windows": ["-DEFAULTLIB:advapi32.lib"], # InitializeSecurityDescriptor, SetSecurityDescriptorDacl + ":windows_msvc": ["-DEFAULTLIB:advapi32.lib"], "//conditions:default": ["-lpthread"], }), visibility = ["//visibility:public"], diff --git a/third_party/markdown.BUILD b/third_party/markdown.BUILD deleted file mode 100644 index fa3e85d5304083ed0de521c93c5ea1df1f477349..0000000000000000000000000000000000000000 --- a/third_party/markdown.BUILD +++ /dev/null @@ -1,15 +0,0 @@ -# Description: -# Markdown processor - -package(default_visibility = ["//visibility:public"]) - -# This software says they use a BSD license. 
-licenses(["notice"]) - -exports_files(["LICENSE.md"]) - -py_library( - name = "org_pythonhosted_markdown", - srcs = glob(["markdown/**/*.py"]), - srcs_version = "PY2AND3", -) diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index 8c86766effa97a08f6089194a5d9202da0e003b3..b27d341404c4ee1ca1e87ff3b9f427ec52eba739 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -1,4 +1,6 @@ -licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like TODO +licenses(["notice"]) # 3-Clause BSD + +exports_files(["LICENSE"]) config_setting( name = "using_mkl", @@ -16,10 +18,9 @@ load( cc_library( name = "intel_binary_blob", srcs = if_mkl([ - "libdl.so.2", - "libmklml_intel.so", - "libiomp5.so", + "@mkl//:libmklml_intel.so", + "@mkl//:libiomp5.so", ]), - includes = ["."], visibility = ["//visibility:public"], + deps = ["@mkl//:mkl_headers"], ) diff --git a/third_party/mkl/LICENSE b/third_party/mkl/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9c8f3ea0871e0bfe81da0fa6e7c1d7d156dc380e --- /dev/null +++ b/third_party/mkl/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 9a28b312c2de68481df774875330fd570f336ad3..533c0766c71a18e614f2f101a4e74b7f35fd26c3 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -1,4 +1,16 @@ -# Macros for building MKL code. +# -*- Python -*- +"""Skylark macros for MKL. +if_mkl is a conditional to check if MKL is enabled or not. + +mkl_repository is a repository rule for creating MKL repository rule that can +be pointed to either a local folder, or download it from the internet. +mkl_repository depends on the following environment variables: + * `TF_MKL_ROOT`: The root folder where a copy of libmkl is located. +""" + + +_TF_MKL_ROOT = "TF_MKL_ROOT" + def if_mkl(if_true, if_false = []): """Shorthand for select()'ing on whether we're building with MKL. @@ -11,3 +23,46 @@ def if_mkl(if_true, if_false = []): "//third_party/mkl:using_mkl": if_true, "//conditions:default": if_false }) + + +def _enable_local_mkl(repository_ctx): + return _TF_MKL_ROOT in repository_ctx.os.environ + + +def _mkl_autoconf_impl(repository_ctx): + """Implementation of the local_mkl_autoconf repository rule.""" + + if _enable_local_mkl(repository_ctx): + # Symlink lib and include local folders. 
+ mkl_root = repository_ctx.os.environ[_TF_MKL_ROOT] + mkl_lib_path = "%s/lib" % mkl_root + repository_ctx.symlink(mkl_lib_path, "lib") + mkl_include_path = "%s/include" % mkl_root + repository_ctx.symlink(mkl_include_path, "include") + mkl_license_path = "%s/license.txt" % mkl_root + repository_ctx.symlink(mkl_license_path, "license.txt") + else: + # setup remote mkl repository. + repository_ctx.download_and_extract( + repository_ctx.attr.urls, + sha256=repository_ctx.attr.sha256, + stripPrefix=repository_ctx.attr.strip_prefix, + ) + + # Also setup BUILD file. + repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD") + + +mkl_repository = repository_rule( + implementation = _mkl_autoconf_impl, + environ = [ + _TF_MKL_ROOT, + ], + attrs = { + "build_file": attr.label(), + "repository": attr.string(), + "urls": attr.string_list(default = []), + "sha256": attr.string(default = ""), + "strip_prefix": attr.string(default = ""), + }, +) diff --git a/third_party/mkl/mkl.BUILD b/third_party/mkl/mkl.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8db97232e156b46091b379b0771239f55d6ea5ad --- /dev/null +++ b/third_party/mkl/mkl.BUILD @@ -0,0 +1,30 @@ +licenses(["notice"]) # 3-Clause BSD + +exports_files(["license.txt"]) + +filegroup( + name = "LICENSE", + srcs = [ + "license.txt", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "mkl_headers", + srcs = glob(["include/*"]), + includes = ["include"], + visibility = ["//visibility:public"], +) + +filegroup( + name = "libmklml_intel.so", + srcs = ["lib/libmklml_intel.so"], + visibility = ["//visibility:public"], +) + +filegroup( + name = "libiomp5.so", + srcs = ["lib/libiomp5.so"], + visibility = ["//visibility:public"], +) diff --git a/third_party/nanopb.BUILD b/third_party/nanopb.BUILD deleted file mode 100644 index d21866911b862f0d4adf76c3a07e2732128a6102..0000000000000000000000000000000000000000 --- a/third_party/nanopb.BUILD +++ /dev/null @@ -1,23 +0,0 @@ -# 
Description: -# Nanopb, a tiny ANSI C protobuf implementation for use on embedded devices. - -licenses(["notice"]) # zlib license - -exports_files(["LICENSE.txt"]) - -cc_library( - name = "nanopb", - srcs = [ - "pb_common.c", - "pb_decode.c", - "pb_encode.c", - ], - hdrs = [ - "pb.h", - "pb_common.h", - "pb_decode.h", - "pb_encode.h", - ], - includes = ["."], - visibility = ["//visibility:public"], -) diff --git a/third_party/polymer.bzl b/third_party/polymer.bzl deleted file mode 100644 index bd6e05803cf39192092fb20015c7abe520e8903e..0000000000000000000000000000000000000000 --- a/third_party/polymer.bzl +++ /dev/null @@ -1,1335 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# TensorBoard Polymer Dependencies - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") - -def tensorboard_polymer_workspace(): - web_library_external( - name = "org_polymer_font_roboto", - licenses = ["notice"], # BSD-3-Clause - sha256 = "fae51429b56a4a4c15f1f0c23b733c7095940cc9c04c275fa7adb3bf055b23b3", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/font-roboto/archive/v1.0.1.tar.gz", - "https://github.com/PolymerElements/font-roboto/archive/v1.0.1.tar.gz", - ], - strip_prefix = "font-roboto-1.0.1", - path = "/font-roboto", - srcs = ["roboto.html"], - ) - - web_library_external( - name = "org_polymer_hydrolysis", - licenses = ["notice"], # BSD-3-Clause - sha256 = "703b50f6b00f9e0546b5a3451da57bb20f77a166e27e4967923b9e835bab9b80", - urls = [ - "http://mirror.bazel.build/github.com/Polymer/polymer-analyzer/archive/v1.19.3.tar.gz", - "https://github.com/Polymer/polymer-analyzer/archive/v1.19.3.tar.gz", - ], - strip_prefix = "polymer-analyzer-1.19.3", - path = "/hydrolysis", - srcs = [ - "hydrolysis-analyzer.html", - "hydrolysis.html", - "hydrolysis.js", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_a11y_announcer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6bce143db7a374a68535ec8b861a5f30e81f2f1e4ee36a55bda2a891f6fd2818", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - "https://github.com/PolymerElements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - ], - strip_prefix = "iron-a11y-announcer-1.0.5", - path = "/iron-a11y-announcer", - srcs = ["iron-a11y-announcer.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_a11y_keys_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6823efc47a83208fd51d39c5a1d3eb0c0bebc705df1ce01310509da22a13ebd2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - 
"https://github.com/PolymerElements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - ], - strip_prefix = "iron-a11y-keys-behavior-1.1.8", - path = "/iron-a11y-keys-behavior", - srcs = ["iron-a11y-keys-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_ajax", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9162d8af4611e911ac3ebbfc08bb7038ac04f6e79a9287b1476fe36ad6770bc5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-ajax/archive/v1.2.0.tar.gz", - "https://github.com/PolymerElements/iron-ajax/archive/v1.2.0.tar.gz", - ], - strip_prefix = "iron-ajax-1.2.0", - path = "/iron-ajax", - srcs = [ - "iron-ajax.html", - "iron-request.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_promise_polyfill", - ], - ) - - web_library_external( - name = "org_polymer_iron_autogrow_textarea", - licenses = ["notice"], # BSD-3-Clause - sha256 = "50bbb901d2c8f87462e3552e3d671a552faa12c37c485e548d7a234ebffbc427", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - ], - strip_prefix = "iron-autogrow-textarea-1.0.12", - path = "/iron-autogrow-textarea", - srcs = ["iron-autogrow-textarea.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_behaviors", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a1e8d4b7a13f3d36beba9c2a6b186ed33a53e6af2e79f98c1fcc7e85e7b53f89", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-behaviors/archive/v1.0.17.tar.gz", - "https://github.com/PolymerElements/iron-behaviors/archive/v1.0.17.tar.gz", - ], - strip_prefix = "iron-behaviors-1.0.17", - path = "/iron-behaviors", - srcs = [ - 
"iron-button-state.html", - "iron-control-state.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_checked_element_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "539a0e1c4df0bc702d3bd342388e4e56c77ec4c2066cce69e41426a69f92e8bd", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - ], - strip_prefix = "iron-checked-element-behavior-1.0.4", - path = "/iron-checked-element-behavior", - srcs = ["iron-checked-element-behavior.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_component_page", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3636e8b9a1f229fc33b5aad3933bd02a9825f66e679a0be31855d7c8245c4b4b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-component-page/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/iron-component-page/archive/v1.1.4.tar.gz", - ], - strip_prefix = "iron-component-page-1.1.4", - path = "/iron-component-page", - srcs = ["iron-component-page.html"], - deps = [ - "@org_polymer", - "@org_polymer_hydrolysis", - "@org_polymer_iron_ajax", - "@org_polymer_iron_doc_viewer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_iron_selector", - "@org_polymer_paper_header_panel", - "@org_polymer_paper_styles", - "@org_polymer_paper_toolbar", - ], - ) - - web_library_external( - name = "org_polymer_iron_collapse", - licenses = ["notice"], # BSD-3-Clause - sha256 = "275808994a609a2f9923e2dd2db1957945ab141ba840eadc33f19e1f406d600e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-collapse/archive/v1.0.8.tar.gz", - 
"https://github.com/PolymerElements/iron-collapse/archive/v1.0.8.tar.gz", - ], - strip_prefix = "iron-collapse-1.0.8", - path = "/iron-collapse", - srcs = ["iron-collapse.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_resizable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_demo_helpers", - licenses = ["notice"], # BSD-3-Clause - sha256 = "aa7458492a6ac3d1f6344640a4c2ab07bce64e7ad0422b83b5d665707598cce6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-demo-helpers/archive/v1.1.0.tar.gz", - "https://github.com/PolymerElements/iron-demo-helpers/archive/v1.1.0.tar.gz", - ], - strip_prefix = "iron-demo-helpers-1.1.0", - path = "/iron-demo-helpers", - srcs = [ - "demo-pages-shared-styles.html", - "demo-snippet.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_marked_element", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - "@org_polymer_prism_element", - ], - ) - - web_library_external( - name = "org_polymer_iron_doc_viewer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f0e9dfbbcd94d7e88ce82cb61e615406ace63c185fee9396f7f182206ca5cc9a", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-doc-viewer/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/iron-doc-viewer/archive/v1.0.12.tar.gz", - ], - strip_prefix = "iron-doc-viewer-1.0.12", - path = "/iron-doc-viewer", - srcs = [ - "iron-doc-property-styles.html", - "iron-doc-property.html", - "iron-doc-viewer-styles.html", - "iron-doc-viewer.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_marked_element", - "@org_polymer_paper_button", - "@org_polymer_paper_styles", - "@org_polymer_prism_element", - ], - ) - - web_library_external( - name = "org_polymer_iron_dropdown", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f7e4a31d096d10d8af1920397695cb17f3eb1cbe5e5ff91a861dabfcc085f376", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/iron-dropdown/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/iron-dropdown/archive/v1.4.0.tar.gz", - ], - strip_prefix = "iron-dropdown-1.4.0", - path = "/iron-dropdown", - srcs = [ - "iron-dropdown.html", - "iron-dropdown-scroll-manager.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_overlay_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_neon_animation", - ], - ) - - web_library_external( - name = "org_polymer_iron_fit_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "10132a2ea309a37c4c07b8fead71f64abc588ee6107931e34680f5f36dd8291e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-fit-behavior/archive/v1.2.5.tar.gz", - "https://github.com/PolymerElements/iron-fit-behavior/archive/v1.2.5.tar.gz", - ], - strip_prefix = "iron-fit-behavior-1.2.5", - path = "/iron-fit-behavior", - srcs = ["iron-fit-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_flex_layout", - licenses = ["notice"], # BSD-3-Clause - sha256 = "79287f6ca1c2d4e003f68b88fe19d03a1b6a0011e2b4cae579fe4d1474163a2e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-flex-layout/archive/v1.3.0.tar.gz", - "https://github.com/PolymerElements/iron-flex-layout/archive/v1.3.0.tar.gz", - ], - strip_prefix = "iron-flex-layout-1.3.0", - path = "/iron-flex-layout", - srcs = [ - "classes/iron-flex-layout.html", - "classes/iron-shadow-flex-layout.html", - "iron-flex-layout.html", - "iron-flex-layout-classes.html", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_form_element_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "1dd9371c638e5bc2ecba8a64074aa680dfb8712198e9612f9ed24d387efc8f26", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - "https://github.com/PolymerElements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - ], - strip_prefix = "iron-form-element-behavior-1.0.6", - path = "/iron-form-element-behavior", - srcs = ["iron-form-element-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_icon", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9ed58a69159a02c07a6050d242e6d4e585a29f3245b8c8c390cfd52ddb786dc4", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-icon/archive/v1.0.11.tar.gz", - "https://github.com/PolymerElements/iron-icon/archive/v1.0.11.tar.gz", - ], - strip_prefix = "iron-icon-1.0.11", - path = "/iron-icon", - srcs = ["iron-icon.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_iron_icons", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3b18542c147c7923dc3a36b1a51984a73255d610f297d43c9aaccc52859bd0d0", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-icons/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/iron-icons/archive/v1.1.3.tar.gz", - ], - strip_prefix = "iron-icons-1.1.3", - path = "/iron-icons", - srcs = [ - "av-icons.html", - "communication-icons.html", - "device-icons.html", - "editor-icons.html", - "hardware-icons.html", - "image-icons.html", - "iron-icons.html", - "maps-icons.html", - "notification-icons.html", - "places-icons.html", - "social-icons.html", - ], - deps = [ - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - ], - ) - - web_library_external( - name = "org_polymer_iron_iconset_svg", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7e3925b7e63a7d22524c4b43ce16ab80d06a576649644783643c11a003284368", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-iconset-svg/archive/v1.1.0.tar.gz", - 
"https://github.com/PolymerElements/iron-iconset-svg/archive/v1.1.0.tar.gz", - ], - strip_prefix = "iron-iconset-svg-1.1.0", - path = "/iron-iconset-svg", - srcs = ["iron-iconset-svg.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_iron_input", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c505101ead08ab25526b1f49baecc8c28b4221b92a65e7334c783bdc81553c36", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-input/archive/1.0.10.tar.gz", - "https://github.com/PolymerElements/iron-input/archive/1.0.10.tar.gz", - ], - strip_prefix = "iron-input-1.0.10", - path = "/iron-input", - srcs = ["iron-input.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_announcer", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_list", - licenses = ["notice"], # BSD-3-Clause - sha256 = "72a6530b9f0ad5557f5d287845792a0ada74d8b159198e27f940e226313dc116", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-list/archive/v1.3.9.tar.gz", - "https://github.com/PolymerElements/iron-list/archive/v1.3.9.tar.gz", - ], - strip_prefix = "iron-list-1.3.9", - path = "/iron-list", - srcs = ["iron-list.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_iron_scroll_target_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_menu_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ad27889343bc9a709258b073f69abc028bb1ffd3fdb975cd2d3939f7f5d7bb6c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-menu-behavior/archive/v1.1.10.tar.gz", - "https://github.com/PolymerElements/iron-menu-behavior/archive/v1.1.10.tar.gz", - ], - strip_prefix = "iron-menu-behavior-1.1.10", - path = "/iron-menu-behavior", - srcs = [ - "iron-menu-behavior.html", - "iron-menubar-behavior.html", - ], - deps = [ - 
"@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_selector", - ], - ) - - web_library_external( - name = "org_polymer_iron_meta", - licenses = ["notice"], # BSD-3-Clause - sha256 = "fb05e6031bae6b4effe5f15d44b3f548d5807f9e3b3aa2442ba17cf4b8b84361", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-meta/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/iron-meta/archive/v1.1.1.tar.gz", - ], - strip_prefix = "iron-meta-1.1.1", - path = "/iron-meta", - srcs = ["iron-meta.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_overlay_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3df5b54ff2e0510c87a2aff8c9d730d3fe83d3d11277cc1a49fa29b549acb46c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - "https://github.com/PolymerElements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - ], - strip_prefix = "iron-overlay-behavior-1.10.1", - path = "/iron-overlay-behavior", - srcs = [ - "iron-focusables-helper.html", - "iron-overlay-backdrop.html", - "iron-overlay-behavior.html", - "iron-overlay-manager.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_fit_behavior", - "@org_polymer_iron_resizable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_range_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "b2f2b6d52284542330bd30b586e217926eb0adec5e13934a3cef557717c22dc2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-range-behavior/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/iron-range-behavior/archive/v1.0.4.tar.gz", - ], - strip_prefix = "iron-range-behavior-1.0.4", - path = "/iron-range-behavior", - srcs = ["iron-range-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_resizable_behavior", - licenses = ["notice"], # 
BSD-3-Clause - sha256 = "a87a78ee9223c2f6afae7fc94a3ff91cbce6f7e2a7ed3f2979af7945c9281616", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - "https://github.com/PolymerElements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - ], - strip_prefix = "iron-resizable-behavior-1.0.3", - path = "/iron-resizable-behavior", - srcs = ["iron-resizable-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_scroll_target_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "d0de0c804b1ec91d814754144afd9da1cdb082690de88bd5e47fd5f41990746f", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - "https://github.com/PolymerElements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - ], - strip_prefix = "iron-scroll-target-behavior-1.0.3", - path = "/iron-scroll-target-behavior", - srcs = ["iron-scroll-target-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_selector", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ba28a47443bad3b744611c9d7a79fb21dbdf2e35edc5ef8f812e2dcd72b16747", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-selector/archive/v1.5.2.tar.gz", - "https://github.com/PolymerElements/iron-selector/archive/v1.5.2.tar.gz", - ], - strip_prefix = "iron-selector-1.5.2", - path = "/iron-selector", - srcs = [ - "iron-multi-selectable.html", - "iron-selectable.html", - "iron-selection.html", - "iron-selector.html", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_validatable_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "aef4901e68043824f36104799269573dd345ffaac494186e466fdc79c06fdb63", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - 
"https://github.com/PolymerElements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - ], - strip_prefix = "iron-validatable-behavior-1.1.1", - path = "/iron-validatable-behavior", - srcs = ["iron-validatable-behavior.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_marked", - licenses = ["notice"], # MIT - sha256 = "93d30bd593736ca440938d77808b7ef5972da0f3fcfe4ae63ae7b4ce117da2cb", - urls = [ - "http://mirror.bazel.build/github.com/chjj/marked/archive/v0.3.2.zip", - "https://github.com/chjj/marked/archive/v0.3.2.zip", - ], - strip_prefix = "marked-0.3.2", - path = "/marked", - srcs = ["lib/marked.js"], - ) - - web_library_external( - name = "org_polymer_marked_element", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7547616df95f8b903757e6afbabfcdba5322c2bcec3f17c726b8bba5adf4bc5f", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/marked-element/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/marked-element/archive/v1.1.3.tar.gz", - ], - strip_prefix = "marked-element-1.1.3", - path = "/marked-element", - srcs = [ - "marked-element.html", - "marked-import.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_marked", - ], - ) - - web_library_external( - name = "org_polymer_neon_animation", - licenses = ["notice"], # BSD-3-Clause - sha256 = "8800c314a76b2da190a2b203259c1091f6d38e0057ed37c2a3d0b734980fa9a5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/neon-animation/archive/v1.2.2.tar.gz", - "https://github.com/PolymerElements/neon-animation/archive/v1.2.2.tar.gz", - ], - strip_prefix = "neon-animation-1.2.2", - path = "/neon-animation", - srcs = [ - "animations/cascaded-animation.html", - "animations/fade-in-animation.html", - "animations/fade-out-animation.html", - "animations/hero-animation.html", - "animations/opaque-animation.html", - "animations/reverse-ripple-animation.html", - "animations/ripple-animation.html", - 
"animations/scale-down-animation.html", - "animations/scale-up-animation.html", - "animations/slide-down-animation.html", - "animations/slide-from-bottom-animation.html", - "animations/slide-from-left-animation.html", - "animations/slide-from-right-animation.html", - "animations/slide-from-top-animation.html", - "animations/slide-left-animation.html", - "animations/slide-right-animation.html", - "animations/slide-up-animation.html", - "animations/transform-animation.html", - "neon-animatable.html", - "neon-animatable-behavior.html", - "neon-animated-pages.html", - "neon-animation.html", - "neon-animation-behavior.html", - "neon-animation-runner-behavior.html", - "neon-animations.html", - "neon-shared-element-animatable-behavior.html", - "neon-shared-element-animation-behavior.html", - "web-animations.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_iron_selector", - "@org_polymer_web_animations_js", - ], - ) - - web_library_external( - name = "org_polymer_paper_behaviors", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7cfcb9082ef9909da262df6b5c120bc62dbeaff278cb563e8fc60465ddd387e5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-behaviors/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/paper-behaviors/archive/v1.0.12.tar.gz", - ], - strip_prefix = "paper-behaviors-1.0.12", - path = "/paper-behaviors", - srcs = [ - "paper-button-behavior.html", - "paper-checked-element-behavior.html", - "paper-inky-focus-behavior.html", - "paper-ripple-behavior.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_checked_element_behavior", - "@org_polymer_paper_ripple", - ], - ) - - web_library_external( - name = "org_polymer_paper_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "896c0a7e34bfcce63fc23c63e105ed9c4d62fa3a6385b7161e1e5cd4058820a6", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/paper-button/archive/v1.0.11.tar.gz", - "https://github.com/PolymerElements/paper-button/archive/v1.0.11.tar.gz", - ], - strip_prefix = "paper-button-1.0.11", - path = "/paper-button", - srcs = ["paper-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_material", - "@org_polymer_paper_ripple", - ], - ) - - web_library_external( - name = "org_polymer_paper_checkbox", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6828a6954a048b1230fbd2606faffbae950ba1d042175b96ec50ae355786a166", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-checkbox/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/paper-checkbox/archive/v1.4.0.tar.gz", - ], - strip_prefix = "paper-checkbox-1.4.0", - path = "/paper-checkbox", - srcs = ["paper-checkbox.html"], - deps = [ - "@org_polymer", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c6a9709e7f528d03dcd574503c18b72d4751ca30017346d16e6a791d37ed9259", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/paper-dialog/archive/v1.0.4.tar.gz", - ], - strip_prefix = "paper-dialog-1.0.4", - path = "/paper-dialog", - srcs = ["paper-dialog.html"], - deps = [ - "@org_polymer", - "@org_polymer_neon_animation", - "@org_polymer_paper_dialog_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a7e0e27ce63554bc14f384cf94bcfa24da8dc5f5120dfd565f45e166261aee40", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - "https://github.com/PolymerElements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - ], - 
strip_prefix = "paper-dialog-behavior-1.2.5", - path = "/paper-dialog-behavior", - srcs = [ - "paper-dialog-behavior.html", - "paper-dialog-common.css", - "paper-dialog-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_overlay_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog_scrollable", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a2e69283e7674f782c44d811387a0f8da2d01fac0172743d1add65e253e6b5ff", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - "https://github.com/PolymerElements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - ], - strip_prefix = "paper-dialog-scrollable-1.1.5", - path = "/paper-dialog-scrollable", - srcs = ["paper-dialog-scrollable.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_dialog_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dropdown_menu", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9d88f654ec03ee9be211df9e69bede9e8a22b51bf1dbcc63b79762e4256d81ad", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - ], - strip_prefix = "paper-dropdown-menu-1.4.0", - path = "/paper-dropdown-menu", - srcs = [ - "paper-dropdown-menu.html", - "paper-dropdown-menu-icons.html", - "paper-dropdown-menu-light.html", - "paper-dropdown-menu-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - "@org_polymer_iron_validatable_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_input", - "@org_polymer_paper_menu_button", - 
"@org_polymer_paper_ripple", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_header_panel", - licenses = ["notice"], # BSD-3-Clause - sha256 = "0db4bd8a4bf6f20dcd0dffb4f907b31c93a8647c9c021344239cf30b40b87075", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-header-panel/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-header-panel/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-header-panel-1.1.4", - path = "/paper-header-panel", - srcs = ["paper-header-panel.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - ], - ) - - web_library_external( - name = "org_polymer_paper_icon_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9cba5bcfd6aeb4c41581c1392c678cf2278d360e9d122f4d9db54a9ebb404496", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-icon-button/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/paper-icon-button/archive/v1.1.3.tar.gz", - ], - strip_prefix = "paper-icon-button-1.1.3", - path = "/paper-icon-button", - srcs = [ - "paper-icon-button.html", - "paper-icon-button-light.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_icon", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_input", - licenses = ["notice"], # BSD-3-Clause - sha256 = "17c3dea9bb1c2026cc61324696c6c774214a0dc37686b91ca214a6af550994db", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-input/archive/v1.1.18.tar.gz", - "https://github.com/PolymerElements/paper-input/archive/v1.1.18.tar.gz", - ], - strip_prefix = "paper-input-1.1.18", - path = "/paper-input", - srcs = [ - "paper-input.html", - "paper-input-addon-behavior.html", - "paper-input-behavior.html", - "paper-input-char-counter.html", - "paper-input-container.html", - "paper-input-error.html", - "paper-textarea.html", - ], - deps = [ - "@org_polymer", - 
"@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_autogrow_textarea", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_input", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_item", - licenses = ["notice"], # BSD-3-Clause - sha256 = "12ee0dcb61b0d5721c5988571f6974d7b2211e97724f4195893fbcc9058cdac8", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-item/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-item/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-item-1.1.4", - path = "/paper-item", - srcs = [ - "paper-icon-item.html", - "paper-item.html", - "paper-item-behavior.html", - "paper-item-body.html", - "paper-item-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_listbox", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3cb35f4fe9a3f15185a9e91711dba8f27e9291c8cd371ebf1be21b8f1d5f65fb", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-listbox/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-listbox/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-listbox-1.1.2", - path = "/paper-listbox", - srcs = ["paper-listbox.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_menu_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_material", - licenses = ["notice"], # BSD-3-Clause - sha256 = "09f6c8bd6ddbea2be541dc86306efe41cdfb31bec0b69d35a5dc29772bbc8506", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-material/archive/v1.0.6.tar.gz", - "https://github.com/PolymerElements/paper-material/archive/v1.0.6.tar.gz", - ], - strip_prefix = "paper-material-1.0.6", - path = "/paper-material", - srcs = 
[ - "paper-material.html", - "paper-material-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_menu", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a3cee220926e315f7412236b3628288774694447c0da4428345f36d0f127ba3b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-menu/archive/v1.2.2.tar.gz", - "https://github.com/PolymerElements/paper-menu/archive/v1.2.2.tar.gz", - ], - strip_prefix = "paper-menu-1.2.2", - path = "/paper-menu", - srcs = [ - "paper-menu.html", - "paper-menu-shared-styles.html", - "paper-submenu.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_collapse", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_menu_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_menu_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "be3290c288a2bd4f9887213db22c75add99cc29ff4d088100c0bc4eb0e57997b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-menu-button/archive/v1.5.1.tar.gz", - "https://github.com/PolymerElements/paper-menu-button/archive/v1.5.1.tar.gz", - ], - strip_prefix = "paper-menu-button-1.5.1", - path = "/paper-menu-button", - srcs = [ - "paper-menu-button.html", - "paper-menu-button-animations.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_dropdown", - "@org_polymer_neon_animation", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_progress", - licenses = ["notice"], # BSD-3-Clause - sha256 = "2b6776b2f023c1f344feea17ba29b58d879e46f8ed43b7256495054b5183fff6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-progress/archive/v1.0.9.tar.gz", - "https://github.com/PolymerElements/paper-progress/archive/v1.0.9.tar.gz", - ], - strip_prefix = 
"paper-progress-1.0.9", - path = "/paper-progress", - srcs = ["paper-progress.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_range_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_radio_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6e911d0c308aa388136b3af79d1bdcbe5a1f4159cbc79d71efb4ff3b6c0b4e91", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-radio-button/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-radio-button/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-radio-button-1.1.2", - path = "/paper-radio-button", - srcs = ["paper-radio-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_radio_group", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7885ad1f81e9dcc03dcea4139b54a201ff55c18543770cd44f94530046c9e163", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-radio-group/archive/v1.0.9.tar.gz", - "https://github.com/PolymerElements/paper-radio-group/archive/v1.0.9.tar.gz", - ], - strip_prefix = "paper-radio-group-1.0.9", - path = "/paper-radio-group", - srcs = ["paper-radio-group.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_selector", - "@org_polymer_paper_radio_button", - ], - ) - - web_library_external( - name = "org_polymer_paper_ripple", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ba76bfb1c737260a8a103d3ca97faa1f7c3288c7db9b2519f401b7a782147c09", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-ripple/archive/v1.0.5.tar.gz", - "https://github.com/PolymerElements/paper-ripple/archive/v1.0.5.tar.gz", - ], - strip_prefix = "paper-ripple-1.0.5", - path = "/paper-ripple", - srcs = ["paper-ripple.html"], - deps = [ - "@org_polymer", - 
"@org_polymer_iron_a11y_keys_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_slider", - licenses = ["notice"], # BSD-3-Clause - sha256 = "08e7c541dbf5d2e959208810bfc03188e82ced87e4d30d325172967f67962c3c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-slider/archive/v1.0.10.tar.gz", - "https://github.com/PolymerElements/paper-slider/archive/v1.0.10.tar.gz", - ], - strip_prefix = "paper-slider-1.0.10", - path = "/paper-slider", - srcs = ["paper-slider.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_range_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_input", - "@org_polymer_paper_progress", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_spinner", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6a752907fab7899cbeed15b478e7b9299047c15fbf9d1561d6eb4d204bdbd178", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-spinner/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/paper-spinner/archive/v1.1.1.tar.gz", - ], - strip_prefix = "paper-spinner-1.1.1", - path = "/paper-spinner", - srcs = [ - "paper-spinner.html", "paper-spinner-behavior.html", - "paper-spinner-lite.html", "paper-spinner-styles.html" - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_styles", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6d26b0a4c286402098853dc7388f6b22f30dfb7a74e47b34992ac03380144bb2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-styles/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-styles/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-styles-1.1.4", - path = "/paper-styles", - srcs = [ - "classes/global.html", - 
"classes/shadow.html", - "classes/shadow-layout.html", - "classes/typography.html", - "color.html", - "default-theme.html", - "demo.css", - "demo-pages.html", - "paper-styles.html", - "paper-styles-classes.html", - "shadow.html", - "typography.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_font_roboto", - "@org_polymer_iron_flex_layout", - ], - ) - - web_library_external( - name = "org_polymer_paper_tabs", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c23b6a5221db35e5b1ed3eb8e8696b952572563e285adaec96aba1e3134db825", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-tabs/archive/v1.7.0.tar.gz", - "https://github.com/PolymerElements/paper-tabs/archive/v1.7.0.tar.gz", - ], - strip_prefix = "paper-tabs-1.7.0", - path = "/paper-tabs", - srcs = [ - "paper-tab.html", - "paper-tabs.html", - "paper-tabs-icons.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - "@org_polymer_iron_menu_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_toast", - licenses = ["notice"], # BSD-3-Clause - sha256 = "55f623712ed1f2bae6d6fadc522a2458e083ccd44cc0a907672547e7b10758a9", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toast/archive/v1.3.0.tar.gz", - "https://github.com/PolymerElements/paper-toast/archive/v1.3.0.tar.gz", - ], - strip_prefix = "paper-toast-1.3.0", - path = "/paper-toast", - srcs = ["paper-toast.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_announcer", - "@org_polymer_iron_overlay_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_toggle_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4aa7cf0396fa2994a8bc2ac6e8428f48b07b945bb7c41bd52041ef5827b45de3", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/paper-toggle-button/archive/v1.2.0.tar.gz", - "https://github.com/PolymerElements/paper-toggle-button/archive/v1.2.0.tar.gz", - ], - strip_prefix = "paper-toggle-button-1.2.0", - path = "/paper-toggle-button", - srcs = ["paper-toggle-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_toolbar", - licenses = ["notice"], # BSD-3-Clause - sha256 = "dbddffc0654d9fb5fb48843087eebe16bf7a134902495a664c96c11bf8a2c63d", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toolbar/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-toolbar/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-toolbar-1.1.4", - path = "/paper-toolbar", - srcs = ["paper-toolbar.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_tooltip", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4c6667acf01f73da14c3cbc0aa574bf14280304567987ee0314534328377d2ad", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-tooltip/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-tooltip/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-tooltip-1.1.2", - path = "/paper-tooltip", - srcs = ["paper-tooltip.html"], - deps = [ - "@org_polymer", - "@org_polymer_neon_animation", - ], - ) - - web_library_external( - name = "org_polymer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "07a9e62ffb52193da3af09adda2fbac5cc690439978520e2d03e783863f65f91", - strip_prefix = "polymer-1.7.0", - urls = [ - "http://mirror.bazel.build/github.com/polymer/polymer/archive/v1.7.0.tar.gz", - "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz", - ], - path = "/polymer", - srcs = [ - "polymer.html", - "polymer-micro.html", - 
"polymer-mini.html", - ], - ) - - web_library_external( - name = "org_polymer_prism", - licenses = ["notice"], # MIT - sha256 = "e06eb54f2a80e6b3cd0bd4d59f900423bcaee53fc03998a056df63740c684683", - urls = [ - "http://mirror.bazel.build/github.com/PrismJS/prism/archive/abee2b7587f1925e57777044270e2a1860810994.tar.gz", - "https://github.com/PrismJS/prism/archive/abee2b7587f1925e57777044270e2a1860810994.tar.gz", - ], - strip_prefix = "prism-abee2b7587f1925e57777044270e2a1860810994", - path = "/prism", - srcs = [ - "prism.js", - "themes/prism.css", - ], - ) - - web_library_external( - name = "org_polymer_prism_element", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ad70bf9cd5bbdf525d465e1b0658867ab4022193eb9c74087a839044b46312b4", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/prism-element/archive/1.0.4.tar.gz", - "https://github.com/PolymerElements/prism-element/archive/1.0.4.tar.gz", - ], - strip_prefix = "prism-element-1.0.4", - path = "/prism-element", - srcs = [ - "prism-highlighter.html", - "prism-import.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_prism", - ], - ) - - web_library_external( - name = "org_polymer_promise_polyfill", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4495450e5d884c3e16b537b43afead7f84d17c7dc061bcfcbf440eac083e4ef5", - strip_prefix = "promise-polyfill-1.0.0", - urls = [ - "http://mirror.bazel.build/github.com/PolymerLabs/promise-polyfill/archive/v1.0.0.tar.gz", - "https://github.com/PolymerLabs/promise-polyfill/archive/v1.0.0.tar.gz", - ], - path = "/promise-polyfill", - srcs = [ - "Promise.js", - "Promise-Statics.js", - "promise-polyfill.html", - "promise-polyfill-lite.html" - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_web_animations_js", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f8bd760cbdeba131f6790bd5abe170bcbf7b1755ff58ed16d0b82fa8a7f34a7f", - urls = [ - 
"http://mirror.bazel.build/github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - ], - strip_prefix = "web-animations-js-2.2.1", - path = "/web-animations-js", - srcs = ["web-animations-next-lite.min.js"], - ) - - web_library_external( - name = "org_polymer_webcomponentsjs", - licenses = ["notice"], # BSD-3-Clause - sha256 = "138c43306ee0a6d699ddca9b3c6b0f4982974ea8b7bdad291ea7276c72301df9", - urls = [ - "http://mirror.bazel.build/github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - ], - strip_prefix = "webcomponentsjs-0.7.22", - path = "/webcomponentsjs", - srcs = [ - "CustomElements.js", - "CustomElements.min.js", - "HTMLImports.js", - "HTMLImports.min.js", - "MutationObserver.js", - "MutationObserver.min.js", - "ShadowDOM.js", - "ShadowDOM.min.js", - "webcomponents.js", - "webcomponents.min.js", - "webcomponents-lite.js", - "webcomponents-lite.min.js", - ], - ) diff --git a/third_party/pprof.BUILD b/third_party/pprof.BUILD index edd52095949cfdeff5cde3a1c696fe419b01a016..8bd5bacaf12e00101fbabdcd04c40a27d2a900b8 100644 --- a/third_party/pprof.BUILD +++ b/third_party/pprof.BUILD @@ -4,15 +4,15 @@ package( licenses(["notice"]) # MIT -load("@protobuf//:protobuf.bzl", "py_proto_library") +load("@protobuf_archive//:protobuf.bzl", "py_proto_library") exports_files(["pprof/LICENSE"]) py_proto_library( name = "pprof_proto_py", srcs = ["proto/profile.proto"], - default_runtime = "@protobuf//:protobuf_python", - protoc = "@protobuf//:protoc", + default_runtime = "@protobuf_archive//:protobuf_python", + protoc = "@protobuf_archive//:protoc", srcs_version = "PY2AND3", - deps = ["@protobuf//:protobuf_python"], + deps = ["@protobuf_archive//:protobuf_python"], ) diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl index 
1ee9c071adb2d9f4aec84b92277c5067f153b666..de06ad5f27e7c08aade4a8f51ab60ba52d012b7b 100644 --- a/third_party/py/BUILD.tpl +++ b/third_party/py/BUILD.tpl @@ -5,7 +5,17 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "python_headers", hdrs = [":python_include"], + data = select({ + ":windows": [":python_import_lib"], + "//conditions:default": [], + }), includes = ["python_include"], + linkopts = select({ + # TODO(pcloudy): Ideally, this should just go into deps after resolving + # https://github.com/bazelbuild/bazel/issues/3237, + ":windows": ["$(locations :python_import_lib)"], + "//conditions:default": [], + }), ) cc_library( @@ -21,5 +31,5 @@ config_setting( ) %{PYTHON_INCLUDE_GENRULE} - %{NUMPY_INCLUDE_GENRULE} +%{PYTHON_IMPORT_LIB_GENRULE} diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index b4a98af7b6e7742ba99829e6b5e7ce13224cb217..bbc07905fc7f92a26d0aebade66a20209dc3e766 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -9,10 +9,9 @@ * `PYTHON_LIB_PATH`: Location of python libraries. """ -_NUMPY_INCLUDE_PATH = "NUMPY_INCLUDE_PATH" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -_PYTHON_INCLUDE_PATH = "PYTHON_INCLUDE_PATH" _PYTHON_LIB_PATH = "PYTHON_LIB_PATH" +_TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO" def _tpl(repository_ctx, tpl, substitutions={}, out=None): @@ -116,11 +115,11 @@ def _genrule(src_dir, genrule_name, command, outs): genrule_name + '",\n' + ' outs = [\n' + outs + - ' ],\n' + + '\n ],\n' + ' cmd = """\n' + command + - ' """,\n' + - ')\n\n' + '\n """,\n' + + ')\n' ) @@ -132,15 +131,20 @@ def _norm_path(path): return path -def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name): +def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, + src_files = [], dest_files = []): """Returns a genrule to symlink(or copy if on Windows) a set of files. 
+ + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files """ - src_dir = _norm_path(src_dir) - dest_dir = _norm_path(dest_dir) - files = _read_dir(repository_ctx, src_dir) - # Create a list with the src_dir stripped to use for outputs. - dest_files = files.replace(src_dir, '').splitlines() - src_files = files.splitlines() + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = _read_dir(repository_ctx, src_dir) + # Create a list with the src_dir stripped to use for outputs. + dest_files = files.replace(src_dir, '').splitlines() + src_files = files.splitlines() command = [] outs = [] for i in range(len(dest_files)): @@ -151,12 +155,27 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name): # On Windows, symlink is not supported, so we just copy all the files. cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s' command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest)) - outs.append(' "' + dest_dir + dest_files[i] + '",') + outs.append(' "' + dest_dir + dest_files[i] + '",') genrule = _genrule(src_dir, genrule_name, " && ".join(command), "\n".join(outs)) return genrule +def _get_python_bin(repository_ctx): + """Gets the python bin path.""" + python_bin = _get_env_var(repository_ctx, _PYTHON_BIN_PATH, + None, False) + if python_bin != None: + return python_bin + python_bin_path = repository_ctx.which("python") + if python_bin_path != None: + return str(python_bin_path) + path = _get_env_var(repository_ctx, "PATH") + _python_configure_fail("Cannot find python in PATH, please make sure " + + "python is installed and add its directory in PATH, or set the " + + "environment variable PYTHON_BIN_PATH.\nPATH=%s" % (path)) + + def _get_python_lib(repository_ctx, python_bin): """Gets the python lib path.""" print_lib = ("<> /etc/apt/sources.list.d/armhf.list +# echo 'deb [arch=armhf] http://ports.ubuntu.com/ trusty-updates main 
restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +# echo 'deb [arch=armhf] http://ports.ubuntu.com/ trusty-security main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +# echo 'deb [arch=armhf] http://ports.ubuntu.com/ trusty-backports main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +# sed -i 's#deb http://archive.ubuntu.com/ubuntu/#deb [arch=amd64] http://archive.ubuntu.com/ubuntu/#g' /etc/apt/sources.list +# apt-get update +# apt-get install -y libpython-all-dev:armhf +# +# Make sure you have an up to date version of the Bazel build tool installed too. + +yes '' | ./configure + +if [[ $1 == "PI_ONE" ]]; then + PI_COPTS="--copt=-march=armv6 --copt=-mfpu=vfp" + echo "Building for the Pi One/Zero, with no NEON support" +else + PI_COPTS='--copt=-march=armv7-a --copt=-mfpu=neon-vfpv4 + --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1 + --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2 + --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8' + echo "Building for the Pi Two/Three, with NEON acceleration" +fi + +bazel build -c opt ${PI_COPTS} \ + --copt=-funsafe-math-optimizations --copt=-ftree-vectorize \ + --copt=-fomit-frame-pointer --cpu=armeabi \ + --crosstool_top=@local_config_arm_compiler//:toolchain \ + --verbose_failures \ + //tensorflow/tools/benchmark:benchmark_model \ + //tensorflow/tools/pip_package:build_pip_package + +TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) +echo "Final outputs will go to ${TMPDIR}" + +# Build a universal wheel. 
+BDIST_OPTS="--universal" \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package "${TMPDIR}" + +OLD_FN=$(ls "${TMPDIR}" | grep \.whl) +SUB='s/tensorflow-([^-]+)-([^-]+)-.*/tensorflow-\1-\2-none-any.whl/; print' +NEW_FN=$(echo "${OLD_FN}" | perl -ne "${SUB}") +mv "${TMPDIR}/${OLD_FN}" "${TMPDIR}/${NEW_FN}" +cp bazel-bin/tensorflow/tools/benchmark/benchmark_model "${TMPDIR}" + +echo "Output can be found here:" +find "${TMPDIR}" diff --git a/third_party/toolchains/cpus/py/BUILD b/third_party/toolchains/cpus/py/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c175742cbfe918e55035e89b7454596acd43307e --- /dev/null +++ b/third_party/toolchains/cpus/py/BUILD @@ -0,0 +1,197 @@ +# A build file to configure python remote repository used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "python_headers", + hdrs = [":python_include"], + data = select({ + ":windows": [":python_import_lib"], + "//conditions:default": [], + }), + includes = ["python_include"], + linkopts = select({ + # TODO(pcloudy): Ideally, this should just go into deps after resolving + # https://github.com/bazelbuild/bazel/issues/3237, + ":windows": ["$(locations :python_import_lib)"], + "//conditions:default": [], + }), +) + +cc_library( + name = "numpy_headers", + hdrs = [":numpy_include"], + includes = ["numpy_include"], +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +genrule( + name = "python_include", + outs = [ + "python_include/code.h", + "python_include/dtoa.h", + "python_include/tupleobject.h", + "python_include/object.h", + "python_include/ast.h", + "python_include/pymacconfig.h", + "python_include/errcode.h", + "python_include/frameobject.h", + "python_include/pgenheaders.h", + "python_include/cellobject.h", + "python_include/intobject.h", + 
"python_include/pythread.h", + "python_include/cStringIO.h", + "python_include/boolobject.h", + "python_include/modsupport.h", + "python_include/import.h", + "python_include/pymath.h", + "python_include/node.h", + "python_include/funcobject.h", + "python_include/eval.h", + "python_include/longintrepr.h", + "python_include/floatobject.h", + "python_include/rangeobject.h", + "python_include/pyfpe.h", + "python_include/pystrcmp.h", + "python_include/dictobject.h", + "python_include/pyarena.h", + "python_include/objimpl.h", + "python_include/bitset.h", + "python_include/memoryobject.h", + "python_include/bytearrayobject.h", + "python_include/pydebug.h", + "python_include/pyerrors.h", + "python_include/weakrefobject.h", + "python_include/grammar.h", + "python_include/symtable.h", + "python_include/longobject.h", + "python_include/structmember.h", + "python_include/enumobject.h", + "python_include/classobject.h", + "python_include/unicodeobject.h", + "python_include/sliceobject.h", + "python_include/pystrtod.h", + "python_include/genobject.h", + "python_include/pymactoolbox.h", + "python_include/compile.h", + "python_include/pyexpat.h", + "python_include/asdl.h", + "python_include/codecs.h", + "python_include/pyctype.h", + "python_include/sysmodule.h", + "python_include/methodobject.h", + "python_include/graminit.h", + "python_include/cobject.h", + "python_include/intrcheck.h", + "python_include/pyport.h", + "python_include/warnings.h", + "python_include/osdefs.h", + "python_include/fileobject.h", + "python_include/stringobject.h", + "python_include/timefuncs.h", + "python_include/traceback.h", + "python_include/ceval.h", + "python_include/bytes_methods.h", + "python_include/pyconfig.h", + "python_include/Python.h", + "python_include/moduleobject.h", + "python_include/pystate.h", + "python_include/descrobject.h", + "python_include/ucnhash.h", + "python_include/pygetopt.h", + "python_include/pymem.h", + "python_include/complexobject.h", + "python_include/structseq.h", + 
"python_include/datetime.h", + "python_include/pythonrun.h", + "python_include/numpy/oldnumeric.h", + "python_include/numpy/npy_1_7_deprecated_api.h", + "python_include/numpy/ufunc_api.txt", + "python_include/numpy/multiarray_api.txt", + "python_include/numpy/halffloat.h", + "python_include/numpy/npy_common.h", + "python_include/numpy/utils.h", + "python_include/numpy/npy_interrupt.h", + "python_include/numpy/npy_endian.h", + "python_include/numpy/__ufunc_api.h", + "python_include/numpy/_neighborhood_iterator_imp.h", + "python_include/numpy/ufuncobject.h", + "python_include/numpy/ndarraytypes.h", + "python_include/numpy/npy_math.h", + "python_include/numpy/noprefix.h", + "python_include/numpy/npy_3kcompat.h", + "python_include/numpy/arrayscalars.h", + "python_include/numpy/npy_os.h", + "python_include/numpy/ndarrayobject.h", + "python_include/numpy/npy_no_deprecated_api.h", + "python_include/numpy/arrayobject.h", + "python_include/numpy/_numpyconfig.h", + "python_include/numpy/__multiarray_api.h", + "python_include/numpy/npy_cpu.h", + "python_include/numpy/old_defines.h", + "python_include/numpy/numpyconfig.h", + "python_include/pycapsule.h", + "python_include/setobject.h", + "python_include/listobject.h", + "python_include/bytesobject.h", + "python_include/pgen.h", + "python_include/patchlevel.h", + "python_include/opcode.h", + "python_include/parsetok.h", + "python_include/marshal.h", + "python_include/token.h", + "python_include/iterobject.h", + "python_include/abstract.h", + "python_include/py_curses.h", + "python_include/metagrammar.h", + "python_include/bufferobject.h", + "python_include/Python-ast.h", + ], + cmd = """ +cp "/usr/include/python2.7/code.h" "$(@D)/python_include/code.h" && cp "/usr/include/python2.7/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/usr/include/python2.7/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/usr/include/python2.7/object.h" "$(@D)/python_include/object.h" && cp "/usr/include/python2.7/ast.h" 
"$(@D)/python_include/ast.h" && cp "/usr/include/python2.7/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/usr/include/python2.7/errcode.h" "$(@D)/python_include/errcode.h" && cp "/usr/include/python2.7/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/usr/include/python2.7/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/usr/include/python2.7/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/usr/include/python2.7/intobject.h" "$(@D)/python_include/intobject.h" && cp "/usr/include/python2.7/pythread.h" "$(@D)/python_include/pythread.h" && cp "/usr/include/python2.7/cStringIO.h" "$(@D)/python_include/cStringIO.h" && cp "/usr/include/python2.7/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/usr/include/python2.7/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/usr/include/python2.7/import.h" "$(@D)/python_include/import.h" && cp "/usr/include/python2.7/pymath.h" "$(@D)/python_include/pymath.h" && cp "/usr/include/python2.7/node.h" "$(@D)/python_include/node.h" && cp "/usr/include/python2.7/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/usr/include/python2.7/eval.h" "$(@D)/python_include/eval.h" && cp "/usr/include/python2.7/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/usr/include/python2.7/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/usr/include/python2.7/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/usr/include/python2.7/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/usr/include/python2.7/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/usr/include/python2.7/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/usr/include/python2.7/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/usr/include/python2.7/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/usr/include/python2.7/bitset.h" "$(@D)/python_include/bitset.h" && cp "/usr/include/python2.7/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp 
"/usr/include/python2.7/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/usr/include/python2.7/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/usr/include/python2.7/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/usr/include/python2.7/weakrefobject.h" "$(@D)/python_include/weakrefobject.h" && cp "/usr/include/python2.7/grammar.h" "$(@D)/python_include/grammar.h" && cp "/usr/include/python2.7/symtable.h" "$(@D)/python_include/symtable.h" && cp "/usr/include/python2.7/longobject.h" "$(@D)/python_include/longobject.h" && cp "/usr/include/python2.7/structmember.h" "$(@D)/python_include/structmember.h" && cp "/usr/include/python2.7/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/usr/include/python2.7/classobject.h" "$(@D)/python_include/classobject.h" && cp "/usr/include/python2.7/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/usr/include/python2.7/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/usr/include/python2.7/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/usr/include/python2.7/genobject.h" "$(@D)/python_include/genobject.h" && cp "/usr/include/python2.7/pymactoolbox.h" "$(@D)/python_include/pymactoolbox.h" && cp "/usr/include/python2.7/compile.h" "$(@D)/python_include/compile.h" && cp "/usr/include/python2.7/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/usr/include/python2.7/asdl.h" "$(@D)/python_include/asdl.h" && cp "/usr/include/python2.7/codecs.h" "$(@D)/python_include/codecs.h" && cp "/usr/include/python2.7/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/usr/include/python2.7/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/usr/include/python2.7/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/usr/include/python2.7/graminit.h" "$(@D)/python_include/graminit.h" && cp "/usr/include/python2.7/cobject.h" "$(@D)/python_include/cobject.h" && cp "/usr/include/python2.7/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/usr/include/python2.7/pyport.h" 
"$(@D)/python_include/pyport.h" && cp "/usr/include/python2.7/warnings.h" "$(@D)/python_include/warnings.h" && cp "/usr/include/python2.7/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/usr/include/python2.7/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/usr/include/python2.7/stringobject.h" "$(@D)/python_include/stringobject.h" && cp "/usr/include/python2.7/timefuncs.h" "$(@D)/python_include/timefuncs.h" && cp "/usr/include/python2.7/traceback.h" "$(@D)/python_include/traceback.h" && cp "/usr/include/python2.7/ceval.h" "$(@D)/python_include/ceval.h" && cp "/usr/include/python2.7/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/usr/include/python2.7/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/usr/include/python2.7/Python.h" "$(@D)/python_include/Python.h" && cp "/usr/include/python2.7/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/usr/include/python2.7/pystate.h" "$(@D)/python_include/pystate.h" && cp "/usr/include/python2.7/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/usr/include/python2.7/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/usr/include/python2.7/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/usr/include/python2.7/pymem.h" "$(@D)/python_include/pymem.h" && cp "/usr/include/python2.7/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/usr/include/python2.7/structseq.h" "$(@D)/python_include/structseq.h" && cp "/usr/include/python2.7/datetime.h" "$(@D)/python_include/datetime.h" && cp "/usr/include/python2.7/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/usr/include/python2.7/numpy/oldnumeric.h" "$(@D)/python_include/numpy/oldnumeric.h" && cp "/usr/include/python2.7/numpy/npy_1_7_deprecated_api.h" "$(@D)/python_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/include/python2.7/numpy/ufunc_api.txt" "$(@D)/python_include/numpy/ufunc_api.txt" && cp "/usr/include/python2.7/numpy/multiarray_api.txt" "$(@D)/python_include/numpy/multiarray_api.txt" && cp 
"/usr/include/python2.7/numpy/halffloat.h" "$(@D)/python_include/numpy/halffloat.h" && cp "/usr/include/python2.7/numpy/npy_common.h" "$(@D)/python_include/numpy/npy_common.h" && cp "/usr/include/python2.7/numpy/utils.h" "$(@D)/python_include/numpy/utils.h" && cp "/usr/include/python2.7/numpy/npy_interrupt.h" "$(@D)/python_include/numpy/npy_interrupt.h" && cp "/usr/include/python2.7/numpy/npy_endian.h" "$(@D)/python_include/numpy/npy_endian.h" && cp "/usr/include/python2.7/numpy/__ufunc_api.h" "$(@D)/python_include/numpy/__ufunc_api.h" && cp "/usr/include/python2.7/numpy/_neighborhood_iterator_imp.h" "$(@D)/python_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/include/python2.7/numpy/ufuncobject.h" "$(@D)/python_include/numpy/ufuncobject.h" && cp "/usr/include/python2.7/numpy/ndarraytypes.h" "$(@D)/python_include/numpy/ndarraytypes.h" && cp "/usr/include/python2.7/numpy/npy_math.h" "$(@D)/python_include/numpy/npy_math.h" && cp "/usr/include/python2.7/numpy/noprefix.h" "$(@D)/python_include/numpy/noprefix.h" && cp "/usr/include/python2.7/numpy/npy_3kcompat.h" "$(@D)/python_include/numpy/npy_3kcompat.h" && cp "/usr/include/python2.7/numpy/arrayscalars.h" "$(@D)/python_include/numpy/arrayscalars.h" && cp "/usr/include/python2.7/numpy/npy_os.h" "$(@D)/python_include/numpy/npy_os.h" && cp "/usr/include/python2.7/numpy/ndarrayobject.h" "$(@D)/python_include/numpy/ndarrayobject.h" && cp "/usr/include/python2.7/numpy/npy_no_deprecated_api.h" "$(@D)/python_include/numpy/npy_no_deprecated_api.h" && cp "/usr/include/python2.7/numpy/arrayobject.h" "$(@D)/python_include/numpy/arrayobject.h" && cp "/usr/include/python2.7/numpy/_numpyconfig.h" "$(@D)/python_include/numpy/_numpyconfig.h" && cp "/usr/include/python2.7/numpy/__multiarray_api.h" "$(@D)/python_include/numpy/__multiarray_api.h" && cp "/usr/include/python2.7/numpy/npy_cpu.h" "$(@D)/python_include/numpy/npy_cpu.h" && cp "/usr/include/python2.7/numpy/old_defines.h" "$(@D)/python_include/numpy/old_defines.h" && cp 
"/usr/include/python2.7/numpy/numpyconfig.h" "$(@D)/python_include/numpy/numpyconfig.h" && cp "/usr/include/python2.7/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/usr/include/python2.7/setobject.h" "$(@D)/python_include/setobject.h" && cp "/usr/include/python2.7/listobject.h" "$(@D)/python_include/listobject.h" && cp "/usr/include/python2.7/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/usr/include/python2.7/pgen.h" "$(@D)/python_include/pgen.h" && cp "/usr/include/python2.7/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/usr/include/python2.7/opcode.h" "$(@D)/python_include/opcode.h" && cp "/usr/include/python2.7/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/usr/include/python2.7/marshal.h" "$(@D)/python_include/marshal.h" && cp "/usr/include/python2.7/token.h" "$(@D)/python_include/token.h" && cp "/usr/include/python2.7/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/usr/include/python2.7/abstract.h" "$(@D)/python_include/abstract.h" && cp "/usr/include/python2.7/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/usr/include/python2.7/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/usr/include/python2.7/bufferobject.h" "$(@D)/python_include/bufferobject.h" && cp "/usr/include/python2.7/Python-ast.h" "$(@D)/python_include/Python-ast.h" + """, +) + +genrule( + name = "numpy_include", + outs = [ + "numpy_include/numpy/oldnumeric.h", + "numpy_include/numpy/npy_1_7_deprecated_api.h", + "numpy_include/numpy/ufunc_api.txt", + "numpy_include/numpy/multiarray_api.txt", + "numpy_include/numpy/halffloat.h", + "numpy_include/numpy/npy_common.h", + "numpy_include/numpy/utils.h", + "numpy_include/numpy/npy_interrupt.h", + "numpy_include/numpy/npy_endian.h", + "numpy_include/numpy/__ufunc_api.h", + "numpy_include/numpy/_neighborhood_iterator_imp.h", + "numpy_include/numpy/ufuncobject.h", + "numpy_include/numpy/ndarraytypes.h", + "numpy_include/numpy/npy_math.h", + "numpy_include/numpy/noprefix.h", + 
"numpy_include/numpy/npy_3kcompat.h", + "numpy_include/numpy/arrayscalars.h", + "numpy_include/numpy/npy_os.h", + "numpy_include/numpy/ndarrayobject.h", + "numpy_include/numpy/npy_no_deprecated_api.h", + "numpy_include/numpy/arrayobject.h", + "numpy_include/numpy/_numpyconfig.h", + "numpy_include/numpy/__multiarray_api.h", + "numpy_include/numpy/npy_cpu.h", + "numpy_include/numpy/old_defines.h", + "numpy_include/numpy/numpyconfig.h", + ], + cmd = """ +cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp 
"/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" + """, +) diff --git a/third_party/toolchains/gpus/crosstool/BUILD b/third_party/toolchains/gpus/crosstool/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a8c6b0f0291363f3a7576a70e78b3428fb984957 --- /dev/null +++ 
b/third_party/toolchains/gpus/crosstool/BUILD @@ -0,0 +1,52 @@ +# A build file to configure cc toolchain for GPU build used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "local|compiler": ":cc-compiler-local", + "darwin|compiler": ":cc-compiler-darwin", + }, +) + +cc_toolchain( + name = "cc-compiler-local", + all_files = ":empty", + compiler_files = ":empty", + cpu = "local", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + # To support linker flags that need to go to the start of command line + # we need the toolchain to support parameter files. Parameter files are + # last on the command line and contain all shared libraries to link, so all + # regular options will be left of them. 
+ supports_param_files = 1, +) + +cc_toolchain( + name = "cc-compiler-darwin", + all_files = ":empty", + compiler_files = ":empty", + cpu = "darwin", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + supports_param_files = 0, +) + +filegroup( + name = "empty", + srcs = [], +) diff --git a/third_party/toolchains/gpus/crosstool/CROSSTOOL b/third_party/toolchains/gpus/crosstool/CROSSTOOL new file mode 100644 index 0000000000000000000000000000000000000000..224b8912f6d743ad78b0ce835fdb8aa30e5e1309 --- /dev/null +++ b/third_party/toolchains/gpus/crosstool/CROSSTOOL @@ -0,0 +1,302 @@ +# A crosstool configuration for GPU build used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated file + +major_version: "local" +minor_version: "" +default_target_cpu: "same_as_host" + +default_toolchain { + cpu: "k8" + toolchain_identifier: "local_linux" +} +default_toolchain { + cpu: "piii" + toolchain_identifier: "local_linux" +} +default_toolchain { + cpu: "arm" + toolchain_identifier: "local_linux" +} +default_toolchain { + cpu: "darwin" + toolchain_identifier: "local_darwin" +} +default_toolchain { + cpu: "ppc" + toolchain_identifier: "local_linux" +} + +toolchain { + abi_version: "local" + abi_libc_version: "local" + compiler: "compiler" + host_system_name: "local" + needsPic: true + target_libc: "local" + target_cpu: "local" + target_system_name: "local" + toolchain_identifier: "local_linux" + + feature { + name: "c++11" + flag_set { + action: "c++-compile" + flag_group { + flag: "-std=c++11" + } + } + } + + feature { + name: "stdlib" + flag_set { + action: "c++-link-executable" + action: "c++-link-dynamic-library" + flag_group { + flag: "-lstdc++" + } + } + } + + feature { + name: "determinism" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # Make C++ compilation deterministic. 
Use linkstamping instead of these + # compiler symbols. + flag: "-Wno-builtin-macro-redefined" + flag: "-D__DATE__=\"redacted\"" + flag: "-D__TIMESTAMP__=\"redacted\"" + flag: "-D__TIME__=\"redacted\"" + } + } + } + + feature { + name: "alwayslink" + flag_set { + action: "c++-link-dynamic-library" + action: "c++-link-executable" + flag_group { + flag: "-Wl,-no-as-needed" + } + } + } + + # This feature will be enabled for builds that support pic by bazel. + feature { + name: "pic" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + expand_if_all_available: "pic" + flag: "-fPIC" + } + flag_group { + expand_if_none_available: "pic" + flag: "-fPIE" + } + } + } + + # Security hardening on by default. + feature { + name: "hardening" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases. + # We need to undef it before redefining it as some distributions now + # have it enabled by default. + flag: "-U_FORTIFY_SOURCE" + flag: "-D_FORTIFY_SOURCE=1" + flag: "-fstack-protector" + } + } + flag_set { + action: "c++-link-dynamic-library" + flag_group { + flag: "-Wl,-z,relro,-z,now" + } + } + flag_set { + action: "c++-link-executable" + flag_group { + flag: "-pie" + flag: "-Wl,-z,relro,-z,now" + } + } + } + + feature { + name: "warnings" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # All warnings are enabled. Maybe enable -Werror as well? + flag: "-Wall" + # TODO(ngiraldo): Some parts of the codebase set -Werror and hit this + # warning, so switch it off for now. + flag: "-Wno-invalid-partial-specialization" + } + } + } + + # Keep stack frames for debugging, even in opt mode. 
+  feature {
+    name: "frame-pointer"  # Applies to C and C++ compile actions only.
+    flag_set {
+      action: "c-compile"
+      action: "c++-compile"
+      flag_group {
+        flag: "-fno-omit-frame-pointer"
+      }
+    }
+  }
+
+  feature {
+    name: "build-id"  # Applies to link actions only.
+    flag_set {
+      action: "c++-link-executable"
+      action: "c++-link-dynamic-library"
+      flag_group {
+        # Stamp the binary with a unique identifier (an MD5-based build ID).
+        flag: "-Wl,--build-id=md5"
+        flag: "-Wl,--hash-style=gnu"  # Ask the linker for the GNU-style symbol hash table.
+      }
+    }
+  }
+
+  feature {
+    name: "no-canonical-prefixes"  # Passed to both compile and link actions.
+    flag_set {
+      action: "c-compile"
+      action: "c++-compile"
+      action: "c++-link-executable"
+      action: "c++-link-dynamic-library"
+      flag_group {
+        flag:"-no-canonical-prefixes"  # Driver keeps its install prefix as given (no symlink resolution).
+      }
+    }
+  }
+
+  feature {
+    name: "disable-assertions"  # Implied by "opt" only.
+    flag_set {
+      action: "c-compile"
+      action: "c++-compile"
+      flag_group {
+        flag: "-DNDEBUG"  # Defining NDEBUG compiles out assert() checks.
+      }
+    }
+  }
+
+  feature {
+    name: "linker-bin-path"  # Applies to link actions only.
+
+    flag_set {
+      action: "c++-link-executable"
+      action: "c++-link-dynamic-library"
+      flag_group {
+        flag: "-B/usr/bin/"  # Have the compiler driver search /usr/bin/ for its helper programs.
+      }
+    }
+  }
+
+  feature {
+    name: "common"  # Baseline feature set implied by "opt", "fastbuild" and "dbg".
+    implies: "stdlib"
+    implies: "c++11"
+    implies: "determinism"
+    implies: "alwayslink"
+    implies: "hardening"
+    implies: "warnings"
+    implies: "frame-pointer"
+    implies: "build-id"
+    implies: "no-canonical-prefixes"
+    implies: "linker-bin-path"
+  }
+
+  feature {
+    name: "opt"  # "common" plus "disable-assertions" plus the flags below.
+    implies: "common"
+    implies: "disable-assertions"
+
+    flag_set {
+      action: "c-compile"
+      action: "c++-compile"
+      flag_group {
+        # No debug symbols.
+        # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+        # or even generally? However, that can't happen here, as it requires
+        # special handling in Bazel.
+        flag: "-g0"
+
+        # Conservative choice for -O.
+        # -O3 can increase binary size and even slow down the resulting binaries.
+        # Profile first and / or use FDO if you need better performance than this.
+        flag: "-O2"
+
+        # Removal of unused code and data at link time (can this increase binary size in some cases?).
+        flag: "-ffunction-sections"
+        flag: "-fdata-sections"
+      }
+    }
+    flag_set {
+      action: "c++-link-dynamic-library"
+      action: "c++-link-executable"
+      flag_group {
+        flag: "-Wl,--gc-sections"  # Link-time half of the section flags set for compile actions above.
+      }
+    }
+  }
+
+  feature {
+    name: "fastbuild"  # Adds no flags beyond the "common" feature set.
+    implies: "common"
+  }
+
+  feature {
+    name: "dbg"  # "common" plus debug symbols.
+    implies: "common"
+    flag_set {
+      action: "c-compile"
+      action: "c++-compile"
+      flag_group {
+        flag: "-g"
+      }
+    }
+  }
+
+  # Use clang as the C/C++ compiler (it is registered under the "gcc" tool name).
+  tool_path { name: "gcc" path: "/usr/local/bin/clang" }
+
+  # Use the default system toolchain for everything else.
+  tool_path { name: "ar" path: "/usr/bin/ar" }
+  tool_path { name: "compat-ld" path: "/usr/bin/ld" }  # Same binary as "ld" below.
+  tool_path { name: "cpp" path: "/usr/bin/cpp" }
+  tool_path { name: "dwp" path: "/usr/bin/dwp" }
+  tool_path { name: "gcov" path: "/usr/bin/gcov" }
+  tool_path { name: "ld" path: "/usr/bin/ld" }
+  tool_path { name: "nm" path: "/usr/bin/nm" }
+  tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+  tool_path { name: "objdump" path: "/usr/bin/objdump" }
+  tool_path { name: "strip" path: "/usr/bin/strip" }
+
+  # Enable dynamic linking.
+  linking_mode_flags { mode: DYNAMIC }
+
+  cxx_builtin_include_directory: "/usr/include/c++/5.4.0"  # NOTE(review): version-pinned paths (5.4.0, clang 5.0.0) presumably must match the installed toolchain - confirm.
+  cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/5.4.0"
+  cxx_builtin_include_directory: "/usr/include/c++/5.4.0/backward"
+  cxx_builtin_include_directory: "/usr/local/include"
+  cxx_builtin_include_directory: "/usr/local/lib/clang/5.0.0/include"
+  cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu"
+  cxx_builtin_include_directory: "/usr/include"
+}
diff --git a/third_party/toolchains/gpus/cuda/BUILD b/third_party/toolchains/gpus/cuda/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..36be86cd1021188eccf2f8d16e17c97531a9e09a
--- /dev/null
+++ b/third_party/toolchains/gpus/cuda/BUILD
@@ -0,0 +1,1362 @@
+# A BUILD file that configures the CUDA remote repository used with the Bazel
+# remote execution service.
+# DO NOT EDIT: automatically generated BUILD file
+
+licenses(["restricted"])  # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "using_nvcc",
+    values = {
+        "define": "using_cuda_nvcc=true",  # Matches builds run with --define=using_cuda_nvcc=true.
+    },
+)
+
+config_setting(
+    name = "using_clang",
+    values = {
+        "define": "using_cuda_clang=true",  # Matches builds run with --define=using_cuda_clang=true.
+    },
+)
+
+# Equivalent to using_clang && -c opt.
+config_setting( + name = "using_clang_opt", + values = { + "define": "using_cuda_clang=true", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "darwin", + values = {"cpu": "darwin"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-include", + ":cudnn-include", + ], + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart_static", + srcs = ["cuda/lib/libcudart_static.a"], + includes = [ + ".", + "cuda/include", + ], + linkopts = select({ + ":freebsd": [], + "//conditions:default": ["-ldl"], + }) + [ + "-lpthread", + "-lrt", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda_driver", + srcs = ["cuda/lib/libcuda.so"], + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + srcs = ["cuda/lib/libcudart.so.8.0"], + data = ["cuda/lib/libcudart.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cublas", + srcs = ["cuda/lib/libcublas.so.8.0"], + data = ["cuda/lib/libcublas.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cusolver", + srcs = ["cuda/lib/libcusolver.so.8.0"], + data = ["cuda/lib/libcusolver.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkopts = ["-lgomp"], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudnn", + srcs = ["cuda/lib/libcudnn.so.6"], + data = ["cuda/lib/libcudnn.so.6"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cufft", + srcs = ["cuda/lib/libcufft.so.8.0"], + data = 
["cuda/lib/libcufft.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "curand", + srcs = ["cuda/lib/libcurand.so.8.0"], + data = ["cuda/lib/libcurand.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda", + visibility = ["//visibility:public"], + deps = [ + ":cublas", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +cc_library( + name = "cupti_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-extras", + ], + includes = [ + ".", + "cuda/extras/CUPTI/include/", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cupti_dsos", + data = ["cuda/lib/libcupti.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], + visibility = ["//visibility:public"], +) + +genrule( + name = "cuda-include", + outs = [ + "cuda/include/math_functions.hpp", + "cuda/include/cufft.h", + "cuda/include/nvgraph.h", + "cuda/include/curand_normal.h", + "cuda/include/curand_uniform.h", + "cuda/include/nppi_data_exchange_and_initialization.h", + "cuda/include/cuda_gl_interop.h", + "cuda/include/nppi_compression_functions.h", + "cuda/include/npp.h", + "cuda/include/cuda.h", + "cuda/include/nppi_statistics_functions.h", + "cuda/include/vector_functions.hpp", + "cuda/include/sm_32_intrinsics.hpp", + "cuda/include/sm_32_intrinsics.h", + "cuda/include/curand_discrete.h", + "cuda/include/cuda_runtime.h", + "cuda/include/cufftXt.h", + "cuda/include/sm_61_intrinsics.h", + "cuda/include/texture_fetch_functions.h", + "cuda/include/curand_mrg32k3a.h", + "cuda/include/host_defines.h", + "cuda/include/common_functions.h", + "cuda/include/nppi_support_functions.h", + "cuda/include/nppi_linear_transforms.h", + "cuda/include/device_double_functions.hpp", + 
"cuda/include/math_constants.h", + "cuda/include/nvToolsExtSync.h", + "cuda/include/npps_initialization.h", + "cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h", + "cuda/include/texture_indirect_functions.hpp", + "cuda/include/cudaProfiler.h", + "cuda/include/npps_filtering_functions.h", + "cuda/include/cusparse_v2.h", + "cuda/include/nppi.h", + "cuda/include/surface_indirect_functions.h", + "cuda/include/sm_30_intrinsics.h", + "cuda/include/device_double_functions.h", + "cuda/include/sm_35_intrinsics.h", + "cuda/include/cusolverSp.h", + "cuda/include/library_types.h", + "cuda/include/surface_indirect_functions.hpp", + "cuda/include/cudalibxt.h", + "cuda/include/channel_descriptor.h", + "cuda/include/device_functions_decls.h", + "cuda/include/curand_kernel.h", + "cuda/include/curand_mtgp32_host.h", + "cuda/include/nvToolsExtCuda.h", + "cuda/include/nvToolsExt.h", + "cuda/include/cuComplex.h", + "cuda/include/sm_32_atomic_functions.h", + "cuda/include/texture_indirect_functions.h", + "cuda/include/sm_32_atomic_functions.hpp", + "cuda/include/sm_20_intrinsics.hpp", + "cuda/include/device_launch_parameters.h", + "cuda/include/curand_mtgp32.h", + "cuda/include/texture_fetch_functions.hpp", + "cuda/include/cuda_occupancy.h", + "cuda/include/CL/opencl.h", + "cuda/include/CL/cl_platform.h", + "cuda/include/CL/cl_egl.h", + "cuda/include/CL/cl_gl.h", + "cuda/include/CL/cl.h", + "cuda/include/CL/cl_gl_ext.h", + "cuda/include/CL/cl_ext.h", + "cuda/include/CL/cl.hpp", + "cuda/include/host_config.h", + "cuda/include/cuda_surface_types.h", + "cuda/include/math_functions.h", + "cuda/include/nvToolsExtMeta.h", + "cuda/include/sm_20_atomic_functions.hpp", + "cuda/include/device_functions.h", + "cuda/include/device_types.h", + "cuda/include/npps_conversion_functions.h", + "cuda/include/curand_precalc.h", + "cuda/include/cusolverRf.h", + "cuda/include/sm_60_atomic_functions.hpp", + "cuda/include/cuviddec.h", + "cuda/include/curand_discrete2.h", + "cuda/include/device_functions.hpp", + 
"cuda/include/thrust/transform_scan.h", + "cuda/include/thrust/system_error.h", + "cuda/include/thrust/device_malloc.h", + "cuda/include/thrust/partition.h", + "cuda/include/thrust/unique.h", + "cuda/include/thrust/device_delete.h", + "cuda/include/thrust/execution_policy.h", + "cuda/include/thrust/adjacent_difference.h", + "cuda/include/thrust/sequence.h", + "cuda/include/thrust/merge.h", + "cuda/include/thrust/device_new.h", + "cuda/include/thrust/transform_reduce.h", + "cuda/include/thrust/device_vector.h", + "cuda/include/thrust/gather.h", + "cuda/include/thrust/sort.h", + "cuda/include/thrust/scan.h", + "cuda/include/thrust/detail/temporary_array.h", + "cuda/include/thrust/detail/util/align.h", + "cuda/include/thrust/detail/util/blocking.h", + "cuda/include/thrust/detail/transform.inl", + "cuda/include/thrust/detail/device_vector.inl", + "cuda/include/thrust/detail/binary_search.inl", + "cuda/include/thrust/detail/overlapped_copy.h", + "cuda/include/thrust/detail/vector_base.inl", + "cuda/include/thrust/detail/device_reference.inl", + "cuda/include/thrust/detail/functional/actor.h", + "cuda/include/thrust/detail/functional/value.h", + "cuda/include/thrust/detail/functional/operators.h", + "cuda/include/thrust/detail/functional/operators/logical_operators.h", + "cuda/include/thrust/detail/functional/operators/relational_operators.h", + "cuda/include/thrust/detail/functional/operators/assignment_operator.h", + "cuda/include/thrust/detail/functional/operators/bitwise_operators.h", + "cuda/include/thrust/detail/functional/operators/operator_adaptors.h", + "cuda/include/thrust/detail/functional/operators/arithmetic_operators.h", + "cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h", + "cuda/include/thrust/detail/functional/argument.h", + "cuda/include/thrust/detail/functional/placeholder.h", + "cuda/include/thrust/detail/functional/actor.inl", + "cuda/include/thrust/detail/functional/composite.h", + 
"cuda/include/thrust/detail/static_map.h", + "cuda/include/thrust/detail/type_traits/has_nested_type.h", + "cuda/include/thrust/detail/type_traits/is_call_possible.h", + "cuda/include/thrust/detail/type_traits/function_traits.h", + "cuda/include/thrust/detail/type_traits/pointer_traits.h", + "cuda/include/thrust/detail/type_traits/has_member_function.h", + "cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h", + "cuda/include/thrust/detail/type_traits/minimum_type.h", + "cuda/include/thrust/detail/type_traits/has_trivial_assign.h", + "cuda/include/thrust/detail/type_traits/is_metafunction_defined.h", + "cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h", + "cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h", + "cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h", + "cuda/include/thrust/detail/reference.h", + "cuda/include/thrust/detail/inner_product.inl", + "cuda/include/thrust/detail/use_default.h", + "cuda/include/thrust/detail/sequence.inl", + "cuda/include/thrust/detail/sort.inl", + "cuda/include/thrust/detail/equal.inl", + "cuda/include/thrust/detail/execution_policy.h", + "cuda/include/thrust/detail/integer_traits.h", + "cuda/include/thrust/detail/type_traits.h", + "cuda/include/thrust/detail/reverse.inl", + "cuda/include/thrust/detail/tabulate.inl", + "cuda/include/thrust/detail/unique.inl", + "cuda/include/thrust/detail/scatter.inl", + "cuda/include/thrust/detail/set_operations.inl", + "cuda/include/thrust/detail/device_malloc.inl", + "cuda/include/thrust/detail/copy_if.inl", + "cuda/include/thrust/detail/fill.inl", + "cuda/include/thrust/detail/temporary_array.inl", + "cuda/include/thrust/detail/transform_scan.inl", + "cuda/include/thrust/detail/minmax.h", + "cuda/include/thrust/detail/swap.inl", + "cuda/include/thrust/detail/pointer.inl", + "cuda/include/thrust/detail/transform_reduce.inl", + "cuda/include/thrust/detail/config.h", + 
"cuda/include/thrust/detail/distance.inl", + "cuda/include/thrust/detail/pair.inl", + "cuda/include/thrust/detail/allocator/temporary_allocator.h", + "cuda/include/thrust/detail/allocator/tagged_allocator.h", + "cuda/include/thrust/detail/allocator/destroy_range.inl", + "cuda/include/thrust/detail/allocator/destroy_range.h", + "cuda/include/thrust/detail/allocator/no_throw_allocator.h", + "cuda/include/thrust/detail/allocator/default_construct_range.inl", + "cuda/include/thrust/detail/allocator/fill_construct_range.inl", + "cuda/include/thrust/detail/allocator/tagged_allocator.inl", + "cuda/include/thrust/detail/allocator/malloc_allocator.h", + "cuda/include/thrust/detail/allocator/allocator_traits.h", + "cuda/include/thrust/detail/allocator/copy_construct_range.h", + "cuda/include/thrust/detail/allocator/allocator_traits.inl", + "cuda/include/thrust/detail/allocator/default_construct_range.h", + "cuda/include/thrust/detail/allocator/copy_construct_range.inl", + "cuda/include/thrust/detail/allocator/malloc_allocator.inl", + "cuda/include/thrust/detail/allocator/temporary_allocator.inl", + "cuda/include/thrust/detail/allocator/fill_construct_range.h", + "cuda/include/thrust/detail/temporary_buffer.h", + "cuda/include/thrust/detail/reduce.inl", + "cuda/include/thrust/detail/device_new.inl", + "cuda/include/thrust/detail/pointer.h", + "cuda/include/thrust/detail/for_each.inl", + "cuda/include/thrust/detail/generate.inl", + "cuda/include/thrust/detail/dispatch/is_trivial_copy.h", + "cuda/include/thrust/detail/adjacent_difference.inl", + "cuda/include/thrust/detail/tuple_meta_transform.h", + "cuda/include/thrust/detail/functional.inl", + "cuda/include/thrust/detail/remove.inl", + "cuda/include/thrust/detail/tuple_transform.h", + "cuda/include/thrust/detail/merge.inl", + "cuda/include/thrust/detail/extrema.inl", + "cuda/include/thrust/detail/trivial_sequence.h", + "cuda/include/thrust/detail/vector_base.h", + "cuda/include/thrust/detail/count.inl", + 
"cuda/include/thrust/detail/uninitialized_copy.inl", + "cuda/include/thrust/detail/function.h", + "cuda/include/thrust/detail/swap_ranges.inl", + "cuda/include/thrust/detail/device_delete.inl", + "cuda/include/thrust/detail/static_assert.h", + "cuda/include/thrust/detail/logical.inl", + "cuda/include/thrust/detail/seq.h", + "cuda/include/thrust/detail/mpl/math.h", + "cuda/include/thrust/detail/mismatch.inl", + "cuda/include/thrust/detail/internal_functional.h", + "cuda/include/thrust/detail/get_iterator_value.h", + "cuda/include/thrust/detail/copy.inl", + "cuda/include/thrust/detail/copy.h", + "cuda/include/thrust/detail/complex/catrigf.h", + "cuda/include/thrust/detail/complex/cpowf.h", + "cuda/include/thrust/detail/complex/csqrtf.h", + "cuda/include/thrust/detail/complex/ccoshf.h", + "cuda/include/thrust/detail/complex/csinhf.h", + "cuda/include/thrust/detail/complex/clogf.h", + "cuda/include/thrust/detail/complex/ccosh.h", + "cuda/include/thrust/detail/complex/arithmetic.h", + "cuda/include/thrust/detail/complex/csqrt.h", + "cuda/include/thrust/detail/complex/cpow.h", + "cuda/include/thrust/detail/complex/complex.inl", + "cuda/include/thrust/detail/complex/math_private.h", + "cuda/include/thrust/detail/complex/c99math.h", + "cuda/include/thrust/detail/complex/cproj.h", + "cuda/include/thrust/detail/complex/catrig.h", + "cuda/include/thrust/detail/complex/ctanhf.h", + "cuda/include/thrust/detail/complex/cexpf.h", + "cuda/include/thrust/detail/complex/csinh.h", + "cuda/include/thrust/detail/complex/stream.h", + "cuda/include/thrust/detail/complex/ctanh.h", + "cuda/include/thrust/detail/complex/cexp.h", + "cuda/include/thrust/detail/complex/clog.h", + "cuda/include/thrust/detail/range/head_flags.h", + "cuda/include/thrust/detail/range/tail_flags.h", + "cuda/include/thrust/detail/execute_with_allocator.h", + "cuda/include/thrust/detail/integer_math.h", + "cuda/include/thrust/detail/swap.h", + "cuda/include/thrust/detail/uninitialized_fill.inl", + 
"cuda/include/thrust/detail/scan.inl", + "cuda/include/thrust/detail/gather.inl", + "cuda/include/thrust/detail/reference_forward_declaration.h", + "cuda/include/thrust/detail/numeric_traits.h", + "cuda/include/thrust/detail/reference.inl", + "cuda/include/thrust/detail/cstdint.h", + "cuda/include/thrust/detail/device_free.inl", + "cuda/include/thrust/detail/copy_if.h", + "cuda/include/thrust/detail/partition.inl", + "cuda/include/thrust/detail/find.inl", + "cuda/include/thrust/detail/config/forceinline.h", + "cuda/include/thrust/detail/config/debug.h", + "cuda/include/thrust/detail/config/config.h", + "cuda/include/thrust/detail/config/host_device.h", + "cuda/include/thrust/detail/config/host_system.h", + "cuda/include/thrust/detail/config/compiler.h", + "cuda/include/thrust/detail/config/device_system.h", + "cuda/include/thrust/detail/config/compiler_fence.h", + "cuda/include/thrust/detail/config/exec_check_disable.h", + "cuda/include/thrust/detail/config/simple_defines.h", + "cuda/include/thrust/detail/config/global_workarounds.h", + "cuda/include/thrust/detail/replace.inl", + "cuda/include/thrust/detail/device_ptr.inl", + "cuda/include/thrust/detail/tuple.inl", + "cuda/include/thrust/detail/malloc_and_free.h", + "cuda/include/thrust/detail/host_vector.inl", + "cuda/include/thrust/detail/raw_pointer_cast.h", + "cuda/include/thrust/detail/advance.inl", + "cuda/include/thrust/detail/contiguous_storage.h", + "cuda/include/thrust/detail/raw_reference_cast.h", + "cuda/include/thrust/detail/contiguous_storage.inl", + "cuda/include/thrust/reverse.h", + "cuda/include/thrust/device_malloc_allocator.h", + "cuda/include/thrust/scatter.h", + "cuda/include/thrust/pair.h", + "cuda/include/thrust/advance.h", + "cuda/include/thrust/find.h", + "cuda/include/thrust/device_ptr.h", + "cuda/include/thrust/generate.h", + "cuda/include/thrust/uninitialized_fill.h", + "cuda/include/thrust/system/system_error.h", + "cuda/include/thrust/system/detail/bad_alloc.h", + 
"cuda/include/thrust/system/detail/adl/transform_scan.h", + "cuda/include/thrust/system/detail/adl/unique_by_key.h", + "cuda/include/thrust/system/detail/adl/partition.h", + "cuda/include/thrust/system/detail/adl/unique.h", + "cuda/include/thrust/system/detail/adl/adjacent_difference.h", + "cuda/include/thrust/system/detail/adl/sequence.h", + "cuda/include/thrust/system/detail/adl/merge.h", + "cuda/include/thrust/system/detail/adl/transform_reduce.h", + "cuda/include/thrust/system/detail/adl/gather.h", + "cuda/include/thrust/system/detail/adl/sort.h", + "cuda/include/thrust/system/detail/adl/scan.h", + "cuda/include/thrust/system/detail/adl/temporary_buffer.h", + "cuda/include/thrust/system/detail/adl/scan_by_key.h", + "cuda/include/thrust/system/detail/adl/reverse.h", + "cuda/include/thrust/system/detail/adl/assign_value.h", + "cuda/include/thrust/system/detail/adl/scatter.h", + "cuda/include/thrust/system/detail/adl/find.h", + "cuda/include/thrust/system/detail/adl/generate.h", + "cuda/include/thrust/system/detail/adl/uninitialized_fill.h", + "cuda/include/thrust/system/detail/adl/remove.h", + "cuda/include/thrust/system/detail/adl/tabulate.h", + "cuda/include/thrust/system/detail/adl/for_each.h", + "cuda/include/thrust/system/detail/adl/reduce_by_key.h", + "cuda/include/thrust/system/detail/adl/reduce.h", + "cuda/include/thrust/system/detail/adl/equal.h", + "cuda/include/thrust/system/detail/adl/copy.h", + "cuda/include/thrust/system/detail/adl/swap_ranges.h", + "cuda/include/thrust/system/detail/adl/uninitialized_copy.h", + "cuda/include/thrust/system/detail/adl/binary_search.h", + "cuda/include/thrust/system/detail/adl/set_operations.h", + "cuda/include/thrust/system/detail/adl/mismatch.h", + "cuda/include/thrust/system/detail/adl/extrema.h", + "cuda/include/thrust/system/detail/adl/count.h", + "cuda/include/thrust/system/detail/adl/replace.h", + "cuda/include/thrust/system/detail/adl/get_value.h", + "cuda/include/thrust/system/detail/adl/inner_product.h", + 
"cuda/include/thrust/system/detail/adl/copy_if.h", + "cuda/include/thrust/system/detail/adl/logical.h", + "cuda/include/thrust/system/detail/adl/iter_swap.h", + "cuda/include/thrust/system/detail/adl/malloc_and_free.h", + "cuda/include/thrust/system/detail/adl/fill.h", + "cuda/include/thrust/system/detail/adl/transform.h", + "cuda/include/thrust/system/detail/errno.h", + "cuda/include/thrust/system/detail/error_category.inl", + "cuda/include/thrust/system/detail/sequential/transform_scan.h", + "cuda/include/thrust/system/detail/sequential/unique_by_key.h", + "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h", + "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl", + "cuda/include/thrust/system/detail/sequential/stable_merge_sort.h", + "cuda/include/thrust/system/detail/sequential/sort.inl", + "cuda/include/thrust/system/detail/sequential/partition.h", + "cuda/include/thrust/system/detail/sequential/unique.h", + "cuda/include/thrust/system/detail/sequential/execution_policy.h", + "cuda/include/thrust/system/detail/sequential/adjacent_difference.h", + "cuda/include/thrust/system/detail/sequential/sequence.h", + "cuda/include/thrust/system/detail/sequential/merge.h", + "cuda/include/thrust/system/detail/sequential/transform_reduce.h", + "cuda/include/thrust/system/detail/sequential/gather.h", + "cuda/include/thrust/system/detail/sequential/sort.h", + "cuda/include/thrust/system/detail/sequential/copy_backward.h", + "cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl", + "cuda/include/thrust/system/detail/sequential/scan.h", + "cuda/include/thrust/system/detail/sequential/temporary_buffer.h", + "cuda/include/thrust/system/detail/sequential/scan_by_key.h", + "cuda/include/thrust/system/detail/sequential/reverse.h", + "cuda/include/thrust/system/detail/sequential/assign_value.h", + "cuda/include/thrust/system/detail/sequential/scatter.h", + "cuda/include/thrust/system/detail/sequential/find.h", + 
"cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl", + "cuda/include/thrust/system/detail/sequential/merge.inl", + "cuda/include/thrust/system/detail/sequential/generate.h", + "cuda/include/thrust/system/detail/sequential/uninitialized_fill.h", + "cuda/include/thrust/system/detail/sequential/general_copy.h", + "cuda/include/thrust/system/detail/sequential/insertion_sort.h", + "cuda/include/thrust/system/detail/sequential/remove.h", + "cuda/include/thrust/system/detail/sequential/tabulate.h", + "cuda/include/thrust/system/detail/sequential/for_each.h", + "cuda/include/thrust/system/detail/sequential/reduce_by_key.h", + "cuda/include/thrust/system/detail/sequential/reduce.h", + "cuda/include/thrust/system/detail/sequential/equal.h", + "cuda/include/thrust/system/detail/sequential/stable_radix_sort.h", + "cuda/include/thrust/system/detail/sequential/copy.inl", + "cuda/include/thrust/system/detail/sequential/copy.h", + "cuda/include/thrust/system/detail/sequential/swap_ranges.h", + "cuda/include/thrust/system/detail/sequential/uninitialized_copy.h", + "cuda/include/thrust/system/detail/sequential/binary_search.h", + "cuda/include/thrust/system/detail/sequential/set_operations.h", + "cuda/include/thrust/system/detail/sequential/mismatch.h", + "cuda/include/thrust/system/detail/sequential/extrema.h", + "cuda/include/thrust/system/detail/sequential/count.h", + "cuda/include/thrust/system/detail/sequential/trivial_copy.h", + "cuda/include/thrust/system/detail/sequential/replace.h", + "cuda/include/thrust/system/detail/sequential/get_value.h", + "cuda/include/thrust/system/detail/sequential/inner_product.h", + "cuda/include/thrust/system/detail/sequential/copy_if.h", + "cuda/include/thrust/system/detail/sequential/logical.h", + "cuda/include/thrust/system/detail/sequential/iter_swap.h", + "cuda/include/thrust/system/detail/sequential/malloc_and_free.h", + "cuda/include/thrust/system/detail/sequential/fill.h", + 
"cuda/include/thrust/system/detail/sequential/transform.h", + "cuda/include/thrust/system/detail/error_condition.inl", + "cuda/include/thrust/system/detail/internal/decompose.h", + "cuda/include/thrust/system/detail/error_code.inl", + "cuda/include/thrust/system/detail/generic/transform_scan.h", + "cuda/include/thrust/system/detail/generic/memory.inl", + "cuda/include/thrust/system/detail/generic/transform.inl", + "cuda/include/thrust/system/detail/generic/binary_search.inl", + "cuda/include/thrust/system/detail/generic/scan_by_key.inl", + "cuda/include/thrust/system/detail/generic/unique_by_key.h", + "cuda/include/thrust/system/detail/generic/inner_product.inl", + "cuda/include/thrust/system/detail/generic/select_system.h", + "cuda/include/thrust/system/detail/generic/sequence.inl", + "cuda/include/thrust/system/detail/generic/sort.inl", + "cuda/include/thrust/system/detail/generic/equal.inl", + "cuda/include/thrust/system/detail/generic/partition.h", + "cuda/include/thrust/system/detail/generic/unique.h", + "cuda/include/thrust/system/detail/generic/adjacent_difference.h", + "cuda/include/thrust/system/detail/generic/tag.h", + "cuda/include/thrust/system/detail/generic/unique_by_key.inl", + "cuda/include/thrust/system/detail/generic/sequence.h", + "cuda/include/thrust/system/detail/generic/type_traits.h", + "cuda/include/thrust/system/detail/generic/merge.h", + "cuda/include/thrust/system/detail/generic/reverse.inl", + "cuda/include/thrust/system/detail/generic/tabulate.inl", + "cuda/include/thrust/system/detail/generic/unique.inl", + "cuda/include/thrust/system/detail/generic/scatter.inl", + "cuda/include/thrust/system/detail/generic/set_operations.inl", + "cuda/include/thrust/system/detail/generic/copy_if.inl", + "cuda/include/thrust/system/detail/generic/transform_reduce.h", + "cuda/include/thrust/system/detail/generic/transform_scan.inl", + "cuda/include/thrust/system/detail/generic/gather.h", + "cuda/include/thrust/system/detail/generic/reduce_by_key.inl", + 
"cuda/include/thrust/system/detail/generic/transform_reduce.inl", + "cuda/include/thrust/system/detail/generic/sort.h", + "cuda/include/thrust/system/detail/generic/distance.inl", + "cuda/include/thrust/system/detail/generic/scan.h", + "cuda/include/thrust/system/detail/generic/temporary_buffer.h", + "cuda/include/thrust/system/detail/generic/reduce.inl", + "cuda/include/thrust/system/detail/generic/scan_by_key.h", + "cuda/include/thrust/system/detail/generic/reverse.h", + "cuda/include/thrust/system/detail/generic/temporary_buffer.inl", + "cuda/include/thrust/system/detail/generic/scatter.h", + "cuda/include/thrust/system/detail/generic/generate.inl", + "cuda/include/thrust/system/detail/generic/adjacent_difference.inl", + "cuda/include/thrust/system/detail/generic/remove.inl", + "cuda/include/thrust/system/detail/generic/advance.h", + "cuda/include/thrust/system/detail/generic/find.h", + "cuda/include/thrust/system/detail/generic/merge.inl", + "cuda/include/thrust/system/detail/generic/scalar/binary_search.inl", + "cuda/include/thrust/system/detail/generic/scalar/binary_search.h", + "cuda/include/thrust/system/detail/generic/extrema.inl", + "cuda/include/thrust/system/detail/generic/generate.h", + "cuda/include/thrust/system/detail/generic/uninitialized_fill.h", + "cuda/include/thrust/system/detail/generic/count.inl", + "cuda/include/thrust/system/detail/generic/remove.h", + "cuda/include/thrust/system/detail/generic/uninitialized_copy.inl", + "cuda/include/thrust/system/detail/generic/tabulate.h", + "cuda/include/thrust/system/detail/generic/for_each.h", + "cuda/include/thrust/system/detail/generic/distance.h", + "cuda/include/thrust/system/detail/generic/swap_ranges.inl", + "cuda/include/thrust/system/detail/generic/reduce_by_key.h", + "cuda/include/thrust/system/detail/generic/reduce.h", + "cuda/include/thrust/system/detail/generic/equal.h", + "cuda/include/thrust/system/detail/generic/mismatch.inl", + "cuda/include/thrust/system/detail/generic/copy.inl", + 
"cuda/include/thrust/system/detail/generic/copy.h", + "cuda/include/thrust/system/detail/generic/swap_ranges.h", + "cuda/include/thrust/system/detail/generic/uninitialized_copy.h", + "cuda/include/thrust/system/detail/generic/binary_search.h", + "cuda/include/thrust/system/detail/generic/set_operations.h", + "cuda/include/thrust/system/detail/generic/uninitialized_fill.inl", + "cuda/include/thrust/system/detail/generic/mismatch.h", + "cuda/include/thrust/system/detail/generic/scan.inl", + "cuda/include/thrust/system/detail/generic/gather.inl", + "cuda/include/thrust/system/detail/generic/extrema.h", + "cuda/include/thrust/system/detail/generic/count.h", + "cuda/include/thrust/system/detail/generic/replace.h", + "cuda/include/thrust/system/detail/generic/inner_product.h", + "cuda/include/thrust/system/detail/generic/copy_if.h", + "cuda/include/thrust/system/detail/generic/logical.h", + "cuda/include/thrust/system/detail/generic/partition.inl", + "cuda/include/thrust/system/detail/generic/memory.h", + "cuda/include/thrust/system/detail/generic/find.inl", + "cuda/include/thrust/system/detail/generic/replace.inl", + "cuda/include/thrust/system/detail/generic/advance.inl", + "cuda/include/thrust/system/detail/generic/fill.h", + "cuda/include/thrust/system/detail/generic/transform.h", + "cuda/include/thrust/system/detail/system_error.inl", + "cuda/include/thrust/system/omp/execution_policy.h", + "cuda/include/thrust/system/omp/vector.h", + "cuda/include/thrust/system/omp/detail/transform_scan.h", + "cuda/include/thrust/system/omp/detail/memory.inl", + "cuda/include/thrust/system/omp/detail/reduce_intervals.inl", + "cuda/include/thrust/system/omp/detail/unique_by_key.h", + "cuda/include/thrust/system/omp/detail/sort.inl", + "cuda/include/thrust/system/omp/detail/partition.h", + "cuda/include/thrust/system/omp/detail/unique.h", + "cuda/include/thrust/system/omp/detail/execution_policy.h", + "cuda/include/thrust/system/omp/detail/adjacent_difference.h", + 
"cuda/include/thrust/system/omp/detail/unique_by_key.inl", + "cuda/include/thrust/system/omp/detail/sequence.h", + "cuda/include/thrust/system/omp/detail/merge.h", + "cuda/include/thrust/system/omp/detail/unique.inl", + "cuda/include/thrust/system/omp/detail/copy_if.inl", + "cuda/include/thrust/system/omp/detail/transform_reduce.h", + "cuda/include/thrust/system/omp/detail/gather.h", + "cuda/include/thrust/system/omp/detail/reduce_by_key.inl", + "cuda/include/thrust/system/omp/detail/sort.h", + "cuda/include/thrust/system/omp/detail/scan.h", + "cuda/include/thrust/system/omp/detail/temporary_buffer.h", + "cuda/include/thrust/system/omp/detail/default_decomposition.h", + "cuda/include/thrust/system/omp/detail/reduce.inl", + "cuda/include/thrust/system/omp/detail/scan_by_key.h", + "cuda/include/thrust/system/omp/detail/reverse.h", + "cuda/include/thrust/system/omp/detail/assign_value.h", + "cuda/include/thrust/system/omp/detail/scatter.h", + "cuda/include/thrust/system/omp/detail/for_each.inl", + "cuda/include/thrust/system/omp/detail/default_decomposition.inl", + "cuda/include/thrust/system/omp/detail/remove.inl", + "cuda/include/thrust/system/omp/detail/vector.inl", + "cuda/include/thrust/system/omp/detail/find.h", + "cuda/include/thrust/system/omp/detail/generate.h", + "cuda/include/thrust/system/omp/detail/uninitialized_fill.h", + "cuda/include/thrust/system/omp/detail/remove.h", + "cuda/include/thrust/system/omp/detail/tabulate.h", + "cuda/include/thrust/system/omp/detail/for_each.h", + "cuda/include/thrust/system/omp/detail/reduce_by_key.h", + "cuda/include/thrust/system/omp/detail/reduce.h", + "cuda/include/thrust/system/omp/detail/equal.h", + "cuda/include/thrust/system/omp/detail/copy.inl", + "cuda/include/thrust/system/omp/detail/copy.h", + "cuda/include/thrust/system/omp/detail/swap_ranges.h", + "cuda/include/thrust/system/omp/detail/uninitialized_copy.h", + "cuda/include/thrust/system/omp/detail/binary_search.h", + 
"cuda/include/thrust/system/omp/detail/set_operations.h", + "cuda/include/thrust/system/omp/detail/mismatch.h", + "cuda/include/thrust/system/omp/detail/extrema.h", + "cuda/include/thrust/system/omp/detail/count.h", + "cuda/include/thrust/system/omp/detail/replace.h", + "cuda/include/thrust/system/omp/detail/get_value.h", + "cuda/include/thrust/system/omp/detail/inner_product.h", + "cuda/include/thrust/system/omp/detail/copy_if.h", + "cuda/include/thrust/system/omp/detail/logical.h", + "cuda/include/thrust/system/omp/detail/partition.inl", + "cuda/include/thrust/system/omp/detail/iter_swap.h", + "cuda/include/thrust/system/omp/detail/par.h", + "cuda/include/thrust/system/omp/detail/reduce_intervals.h", + "cuda/include/thrust/system/omp/detail/malloc_and_free.h", + "cuda/include/thrust/system/omp/detail/fill.h", + "cuda/include/thrust/system/omp/detail/transform.h", + "cuda/include/thrust/system/omp/memory.h", + "cuda/include/thrust/system/tbb/execution_policy.h", + "cuda/include/thrust/system/tbb/vector.h", + "cuda/include/thrust/system/tbb/detail/transform_scan.h", + "cuda/include/thrust/system/tbb/detail/memory.inl", + "cuda/include/thrust/system/tbb/detail/unique_by_key.h", + "cuda/include/thrust/system/tbb/detail/sort.inl", + "cuda/include/thrust/system/tbb/detail/partition.h", + "cuda/include/thrust/system/tbb/detail/unique.h", + "cuda/include/thrust/system/tbb/detail/execution_policy.h", + "cuda/include/thrust/system/tbb/detail/adjacent_difference.h", + "cuda/include/thrust/system/tbb/detail/unique_by_key.inl", + "cuda/include/thrust/system/tbb/detail/sequence.h", + "cuda/include/thrust/system/tbb/detail/merge.h", + "cuda/include/thrust/system/tbb/detail/unique.inl", + "cuda/include/thrust/system/tbb/detail/copy_if.inl", + "cuda/include/thrust/system/tbb/detail/transform_reduce.h", + "cuda/include/thrust/system/tbb/detail/gather.h", + "cuda/include/thrust/system/tbb/detail/reduce_by_key.inl", + "cuda/include/thrust/system/tbb/detail/sort.h", + 
"cuda/include/thrust/system/tbb/detail/scan.h", + "cuda/include/thrust/system/tbb/detail/temporary_buffer.h", + "cuda/include/thrust/system/tbb/detail/reduce.inl", + "cuda/include/thrust/system/tbb/detail/scan_by_key.h", + "cuda/include/thrust/system/tbb/detail/reverse.h", + "cuda/include/thrust/system/tbb/detail/assign_value.h", + "cuda/include/thrust/system/tbb/detail/scatter.h", + "cuda/include/thrust/system/tbb/detail/for_each.inl", + "cuda/include/thrust/system/tbb/detail/remove.inl", + "cuda/include/thrust/system/tbb/detail/vector.inl", + "cuda/include/thrust/system/tbb/detail/find.h", + "cuda/include/thrust/system/tbb/detail/merge.inl", + "cuda/include/thrust/system/tbb/detail/generate.h", + "cuda/include/thrust/system/tbb/detail/uninitialized_fill.h", + "cuda/include/thrust/system/tbb/detail/remove.h", + "cuda/include/thrust/system/tbb/detail/tabulate.h", + "cuda/include/thrust/system/tbb/detail/for_each.h", + "cuda/include/thrust/system/tbb/detail/reduce_by_key.h", + "cuda/include/thrust/system/tbb/detail/reduce.h", + "cuda/include/thrust/system/tbb/detail/equal.h", + "cuda/include/thrust/system/tbb/detail/copy.inl", + "cuda/include/thrust/system/tbb/detail/copy.h", + "cuda/include/thrust/system/tbb/detail/swap_ranges.h", + "cuda/include/thrust/system/tbb/detail/uninitialized_copy.h", + "cuda/include/thrust/system/tbb/detail/binary_search.h", + "cuda/include/thrust/system/tbb/detail/set_operations.h", + "cuda/include/thrust/system/tbb/detail/mismatch.h", + "cuda/include/thrust/system/tbb/detail/scan.inl", + "cuda/include/thrust/system/tbb/detail/extrema.h", + "cuda/include/thrust/system/tbb/detail/count.h", + "cuda/include/thrust/system/tbb/detail/replace.h", + "cuda/include/thrust/system/tbb/detail/get_value.h", + "cuda/include/thrust/system/tbb/detail/inner_product.h", + "cuda/include/thrust/system/tbb/detail/copy_if.h", + "cuda/include/thrust/system/tbb/detail/logical.h", + "cuda/include/thrust/system/tbb/detail/partition.inl", + 
"cuda/include/thrust/system/tbb/detail/iter_swap.h", + "cuda/include/thrust/system/tbb/detail/par.h", + "cuda/include/thrust/system/tbb/detail/reduce_intervals.h", + "cuda/include/thrust/system/tbb/detail/malloc_and_free.h", + "cuda/include/thrust/system/tbb/detail/fill.h", + "cuda/include/thrust/system/tbb/detail/transform.h", + "cuda/include/thrust/system/tbb/memory.h", + "cuda/include/thrust/system/error_code.h", + "cuda/include/thrust/system/cpp/execution_policy.h", + "cuda/include/thrust/system/cpp/vector.h", + "cuda/include/thrust/system/cpp/detail/transform_scan.h", + "cuda/include/thrust/system/cpp/detail/memory.inl", + "cuda/include/thrust/system/cpp/detail/unique_by_key.h", + "cuda/include/thrust/system/cpp/detail/partition.h", + "cuda/include/thrust/system/cpp/detail/unique.h", + "cuda/include/thrust/system/cpp/detail/execution_policy.h", + "cuda/include/thrust/system/cpp/detail/adjacent_difference.h", + "cuda/include/thrust/system/cpp/detail/sequence.h", + "cuda/include/thrust/system/cpp/detail/merge.h", + "cuda/include/thrust/system/cpp/detail/transform_reduce.h", + "cuda/include/thrust/system/cpp/detail/gather.h", + "cuda/include/thrust/system/cpp/detail/sort.h", + "cuda/include/thrust/system/cpp/detail/scan.h", + "cuda/include/thrust/system/cpp/detail/temporary_buffer.h", + "cuda/include/thrust/system/cpp/detail/scan_by_key.h", + "cuda/include/thrust/system/cpp/detail/reverse.h", + "cuda/include/thrust/system/cpp/detail/assign_value.h", + "cuda/include/thrust/system/cpp/detail/scatter.h", + "cuda/include/thrust/system/cpp/detail/vector.inl", + "cuda/include/thrust/system/cpp/detail/find.h", + "cuda/include/thrust/system/cpp/detail/generate.h", + "cuda/include/thrust/system/cpp/detail/uninitialized_fill.h", + "cuda/include/thrust/system/cpp/detail/remove.h", + "cuda/include/thrust/system/cpp/detail/tabulate.h", + "cuda/include/thrust/system/cpp/detail/for_each.h", + "cuda/include/thrust/system/cpp/detail/reduce_by_key.h", + 
"cuda/include/thrust/system/cpp/detail/reduce.h", + "cuda/include/thrust/system/cpp/detail/equal.h", + "cuda/include/thrust/system/cpp/detail/copy.h", + "cuda/include/thrust/system/cpp/detail/swap_ranges.h", + "cuda/include/thrust/system/cpp/detail/uninitialized_copy.h", + "cuda/include/thrust/system/cpp/detail/binary_search.h", + "cuda/include/thrust/system/cpp/detail/set_operations.h", + "cuda/include/thrust/system/cpp/detail/mismatch.h", + "cuda/include/thrust/system/cpp/detail/extrema.h", + "cuda/include/thrust/system/cpp/detail/count.h", + "cuda/include/thrust/system/cpp/detail/replace.h", + "cuda/include/thrust/system/cpp/detail/get_value.h", + "cuda/include/thrust/system/cpp/detail/inner_product.h", + "cuda/include/thrust/system/cpp/detail/copy_if.h", + "cuda/include/thrust/system/cpp/detail/logical.h", + "cuda/include/thrust/system/cpp/detail/iter_swap.h", + "cuda/include/thrust/system/cpp/detail/par.h", + "cuda/include/thrust/system/cpp/detail/malloc_and_free.h", + "cuda/include/thrust/system/cpp/detail/fill.h", + "cuda/include/thrust/system/cpp/detail/transform.h", + "cuda/include/thrust/system/cpp/memory.h", + "cuda/include/thrust/system/cuda/execution_policy.h", + "cuda/include/thrust/system/cuda/vector.h", + "cuda/include/thrust/system/cuda/error.h", + "cuda/include/thrust/system/cuda/detail/copy_device_to_device.h", + "cuda/include/thrust/system/cuda/detail/transform_scan.h", + "cuda/include/thrust/system/cuda/detail/memory.inl", + "cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_device.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_type.cuh", + "cuda/include/thrust/system/cuda/detail/cub/host/spinlock.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh", + "cuda/include/thrust/system/cuda/detail/cub/cub.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_shift.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh", + "cuda/include/thrust/system/cuda/detail/reduce_intervals.inl", + "cuda/include/thrust/system/cuda/detail/copy_cross_system.inl", + "cuda/include/thrust/system/cuda/detail/unique_by_key.h", + "cuda/include/thrust/system/cuda/detail/bulk.h", + "cuda/include/thrust/system/cuda/detail/sort.inl", + "cuda/include/thrust/system/cuda/detail/partition.h", + "cuda/include/thrust/system/cuda/detail/unique.h", + "cuda/include/thrust/system/cuda/detail/execution_policy.h", + "cuda/include/thrust/system/cuda/detail/cuda_launch_config.h", + "cuda/include/thrust/system/cuda/detail/cub.h", + "cuda/include/thrust/system/cuda/detail/adjacent_difference.h", + "cuda/include/thrust/system/cuda/detail/sequence.h", + "cuda/include/thrust/system/cuda/detail/merge.h", + "cuda/include/thrust/system/cuda/detail/set_symmetric_difference.inl", + "cuda/include/thrust/system/cuda/detail/copy_if.inl", + "cuda/include/thrust/system/cuda/detail/transform_reduce.h", + "cuda/include/thrust/system/cuda/detail/error.inl", + "cuda/include/thrust/system/cuda/detail/gather.h", + 
"cuda/include/thrust/system/cuda/detail/reduce_by_key.inl", + "cuda/include/thrust/system/cuda/detail/sort.h", + "cuda/include/thrust/system/cuda/detail/synchronize.h", + "cuda/include/thrust/system/cuda/detail/scan.h", + "cuda/include/thrust/system/cuda/detail/temporary_indirect_permutation.h", + "cuda/include/thrust/system/cuda/detail/extern_shared_ptr.h", + "cuda/include/thrust/system/cuda/detail/detail/set_operation.inl", + "cuda/include/thrust/system/cuda/detail/detail/balanced_path.h", + "cuda/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h", + "cuda/include/thrust/system/cuda/detail/detail/set_operation.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl", + "cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.h", + "cuda/include/thrust/system/cuda/detail/detail/launch_closure.inl", + "cuda/include/thrust/system/cuda/detail/detail/merge.h", + "cuda/include/thrust/system/cuda/detail/detail/alignment.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl", + "cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.h", + "cuda/include/thrust/system/cuda/detail/detail/launch_calculator.inl", + "cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl", + "cuda/include/thrust/system/cuda/detail/detail/launch_closure.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.h", + "cuda/include/thrust/system/cuda/detail/detail/uninitialized.h", + "cuda/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h", + "cuda/include/thrust/system/cuda/detail/detail/launch_calculator.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.inl", + "cuda/include/thrust/system/cuda/detail/temporary_buffer.h", + "cuda/include/thrust/system/cuda/detail/default_decomposition.h", + "cuda/include/thrust/system/cuda/detail/reduce.inl", + 
"cuda/include/thrust/system/cuda/detail/scan_by_key.h", + "cuda/include/thrust/system/cuda/detail/reverse.h", + "cuda/include/thrust/system/cuda/detail/assign_value.h", + "cuda/include/thrust/system/cuda/detail/scatter.h", + "cuda/include/thrust/system/cuda/detail/reduce_intervals.hpp", + "cuda/include/thrust/system/cuda/detail/for_each.inl", + "cuda/include/thrust/system/cuda/detail/default_decomposition.inl", + "cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h", + "cuda/include/thrust/system/cuda/detail/adjacent_difference.inl", + "cuda/include/thrust/system/cuda/detail/vector.inl", + "cuda/include/thrust/system/cuda/detail/throw_on_error.h", + "cuda/include/thrust/system/cuda/detail/find.h", + "cuda/include/thrust/system/cuda/detail/terminate.h", + "cuda/include/thrust/system/cuda/detail/merge.inl", + "cuda/include/thrust/system/cuda/detail/trivial_copy.inl", + "cuda/include/thrust/system/cuda/detail/generate.h", + "cuda/include/thrust/system/cuda/detail/execute_on_stream.h", + "cuda/include/thrust/system/cuda/detail/uninitialized_fill.h", + "cuda/include/thrust/system/cuda/detail/remove.h", + "cuda/include/thrust/system/cuda/detail/tabulate.h", + "cuda/include/thrust/system/cuda/detail/for_each.h", + "cuda/include/thrust/system/cuda/detail/reduce_by_key.h", + "cuda/include/thrust/system/cuda/detail/decomposition.h", + "cuda/include/thrust/system/cuda/detail/reduce.h", + "cuda/include/thrust/system/cuda/detail/equal.h", + "cuda/include/thrust/system/cuda/detail/runtime_introspection.h", + "cuda/include/thrust/system/cuda/detail/copy.inl", + "cuda/include/thrust/system/cuda/detail/copy.h", + "cuda/include/thrust/system/cuda/detail/swap_ranges.h", + "cuda/include/thrust/system/cuda/detail/uninitialized_copy.h", + "cuda/include/thrust/system/cuda/detail/binary_search.h", + "cuda/include/thrust/system/cuda/detail/runtime_introspection.inl", + "cuda/include/thrust/system/cuda/detail/set_operations.h", + 
"cuda/include/thrust/system/cuda/detail/mismatch.h", + "cuda/include/thrust/system/cuda/detail/scan.inl", + "cuda/include/thrust/system/cuda/detail/synchronize.inl", + "cuda/include/thrust/system/cuda/detail/extrema.h", + "cuda/include/thrust/system/cuda/detail/set_union.inl", + "cuda/include/thrust/system/cuda/detail/set_intersection.inl", + "cuda/include/thrust/system/cuda/detail/count.h", + "cuda/include/thrust/system/cuda/detail/trivial_copy.h", + "cuda/include/thrust/system/cuda/detail/copy_device_to_device.inl", + "cuda/include/thrust/system/cuda/detail/replace.h", + "cuda/include/thrust/system/cuda/detail/bulk/malloc.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/config.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/closure.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl", + "cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl", + 
"cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/async.inl", + "cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/iterator.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/bulk.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/execution_policy.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/uninitialized.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/async.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/future.hpp", + 
"cuda/include/thrust/system/cuda/detail/guarded_driver_types.h", + "cuda/include/thrust/system/cuda/detail/get_value.h", + "cuda/include/thrust/system/cuda/detail/inner_product.h", + "cuda/include/thrust/system/cuda/detail/copy_if.h", + "cuda/include/thrust/system/cuda/detail/logical.h", + "cuda/include/thrust/system/cuda/detail/iter_swap.h", + "cuda/include/thrust/system/cuda/detail/block/merge.h", + "cuda/include/thrust/system/cuda/detail/block/inclusive_scan.h", + "cuda/include/thrust/system/cuda/detail/block/merge.inl", + "cuda/include/thrust/system/cuda/detail/block/merging_sort.h", + "cuda/include/thrust/system/cuda/detail/block/exclusive_scan.h", + "cuda/include/thrust/system/cuda/detail/block/reduce.h", + "cuda/include/thrust/system/cuda/detail/block/copy.h", + "cuda/include/thrust/system/cuda/detail/block/odd_even_sort.h", + "cuda/include/thrust/system/cuda/detail/par.h", + "cuda/include/thrust/system/cuda/detail/copy_cross_system.h", + "cuda/include/thrust/system/cuda/detail/reduce_intervals.h", + "cuda/include/thrust/system/cuda/detail/malloc_and_free.h", + "cuda/include/thrust/system/cuda/detail/fill.h", + "cuda/include/thrust/system/cuda/detail/set_difference.inl", + "cuda/include/thrust/system/cuda/detail/transform.h", + "cuda/include/thrust/system/cuda/experimental/pinned_allocator.h", + "cuda/include/thrust/system/cuda/memory.h", + "cuda/include/thrust/remove.h", + "cuda/include/thrust/tabulate.h", + "cuda/include/thrust/for_each.h", + "cuda/include/thrust/distance.h", + "cuda/include/thrust/reduce.h", + "cuda/include/thrust/equal.h", + "cuda/include/thrust/complex.h", + "cuda/include/thrust/device_allocator.h", + "cuda/include/thrust/copy.h", + "cuda/include/thrust/uninitialized_copy.h", + "cuda/include/thrust/device_reference.h", + "cuda/include/thrust/binary_search.h", + "cuda/include/thrust/set_operations.h", + "cuda/include/thrust/swap.h", + "cuda/include/thrust/mismatch.h", + "cuda/include/thrust/extrema.h", + "cuda/include/thrust/count.h", + 
"cuda/include/thrust/device_free.h", + "cuda/include/thrust/random/discard_block_engine.h", + "cuda/include/thrust/random/normal_distribution.h", + "cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h", + "cuda/include/thrust/random/detail/subtract_with_carry_engine.inl", + "cuda/include/thrust/random/detail/xor_combine_engine_max.h", + "cuda/include/thrust/random/detail/linear_congruential_engine_discard.h", + "cuda/include/thrust/random/detail/uniform_int_distribution.inl", + "cuda/include/thrust/random/detail/discard_block_engine.inl", + "cuda/include/thrust/random/detail/uniform_real_distribution.inl", + "cuda/include/thrust/random/detail/random_core_access.h", + "cuda/include/thrust/random/detail/mod.h", + "cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl", + "cuda/include/thrust/random/detail/linear_congruential_engine.inl", + "cuda/include/thrust/random/detail/xor_combine_engine.inl", + "cuda/include/thrust/random/detail/normal_distribution.inl", + "cuda/include/thrust/random/detail/normal_distribution_base.h", + "cuda/include/thrust/random/uniform_int_distribution.h", + "cuda/include/thrust/random/linear_feedback_shift_engine.h", + "cuda/include/thrust/random/xor_combine_engine.h", + "cuda/include/thrust/random/subtract_with_carry_engine.h", + "cuda/include/thrust/random/linear_congruential_engine.h", + "cuda/include/thrust/random/uniform_real_distribution.h", + "cuda/include/thrust/functional.h", + "cuda/include/thrust/replace.h", + "cuda/include/thrust/device_new_allocator.h", + "cuda/include/thrust/host_vector.h", + "cuda/include/thrust/version.h", + "cuda/include/thrust/inner_product.h", + "cuda/include/thrust/iterator/iterator_traits.h", + "cuda/include/thrust/iterator/discard_iterator.h", + "cuda/include/thrust/iterator/retag.h", + "cuda/include/thrust/iterator/permutation_iterator.h", + "cuda/include/thrust/iterator/transform_iterator.h", + "cuda/include/thrust/iterator/detail/reverse_iterator.inl", + 
"cuda/include/thrust/iterator/detail/zip_iterator.inl", + "cuda/include/thrust/iterator/detail/counting_iterator.inl", + "cuda/include/thrust/iterator/detail/distance_from_result.h", + "cuda/include/thrust/iterator/detail/host_system_tag.h", + "cuda/include/thrust/iterator/detail/iterator_traversal_tags.h", + "cuda/include/thrust/iterator/detail/retag.h", + "cuda/include/thrust/iterator/detail/tagged_iterator.h", + "cuda/include/thrust/iterator/detail/iterator_traits.inl", + "cuda/include/thrust/iterator/detail/minimum_category.h", + "cuda/include/thrust/iterator/detail/discard_iterator_base.h", + "cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h", + "cuda/include/thrust/iterator/detail/zip_iterator_base.h", + "cuda/include/thrust/iterator/detail/normal_iterator.h", + "cuda/include/thrust/iterator/detail/join_iterator.h", + "cuda/include/thrust/iterator/detail/device_system_tag.h", + "cuda/include/thrust/iterator/detail/universal_categories.h", + "cuda/include/thrust/iterator/detail/reverse_iterator_base.h", + "cuda/include/thrust/iterator/detail/minimum_system.h", + "cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h", + "cuda/include/thrust/iterator/detail/is_iterator_category.h", + "cuda/include/thrust/iterator/detail/permutation_iterator_base.h", + "cuda/include/thrust/iterator/detail/any_assign.h", + "cuda/include/thrust/iterator/detail/any_system_tag.h", + "cuda/include/thrust/iterator/detail/is_trivial_iterator.h", + "cuda/include/thrust/iterator/detail/iterator_category_to_system.h", + "cuda/include/thrust/iterator/detail/iterator_adaptor_base.h", + "cuda/include/thrust/iterator/detail/constant_iterator_base.h", + "cuda/include/thrust/iterator/detail/transform_iterator.inl", + "cuda/include/thrust/iterator/detail/iterator_facade_category.h", + "cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h", + "cuda/include/thrust/iterator/constant_iterator.h", + 
"cuda/include/thrust/iterator/counting_iterator.h", + "cuda/include/thrust/iterator/iterator_adaptor.h", + "cuda/include/thrust/iterator/iterator_facade.h", + "cuda/include/thrust/iterator/iterator_categories.h", + "cuda/include/thrust/iterator/reverse_iterator.h", + "cuda/include/thrust/iterator/zip_iterator.h", + "cuda/include/thrust/logical.h", + "cuda/include/thrust/tuple.h", + "cuda/include/thrust/memory.h", + "cuda/include/thrust/random.h", + "cuda/include/thrust/fill.h", + "cuda/include/thrust/transform.h", + "cuda/include/texture_types.h", + "cuda/include/nppversion.h", + "cuda/include/cuda_texture_types.h", + "cuda/include/fatbinary.h", + "cuda/include/cublasXt.h", + "cuda/include/cuda_fp16.h", + "cuda/include/vector_functions.h", + "cuda/include/cusparse.h", + "cuda/include/nppi_filtering_functions.h", + "cuda/include/nppi_morphological_operations.h", + "cuda/include/sobol_direction_vectors.h", + "cuda/include/nvblas.h", + "cuda/include/curand_mtgp32dc_p_11213.h", + "cuda/include/nvcuvid.h", + "cuda/include/cuda_runtime_api.h", + "cuda/include/curand_mtgp32_kernel.h", + "cuda/include/cublas_v2.h", + "cuda/include/builtin_types.h", + "cuda/include/nppi_geometry_transforms.h", + "cuda/include/npps_support_functions.h", + "cuda/include/cufftw.h", + "cuda/include/cuda_device_runtime_api.h", + "cuda/include/sm_30_intrinsics.hpp", + "cuda/include/vector_types.h", + "cuda/include/sm_35_atomic_functions.h", + "cuda/include/sm_20_intrinsics.h", + "cuda/include/driver_types.h", + "cuda/include/nvToolsExtCudaRt.h", + "cuda/include/curand_globals.h", + "cuda/include/device_atomic_functions.h", + "cuda/include/surface_types.h", + "cuda/include/nvrtc.h", + "cuda/include/nppdefs.h", + "cuda/include/sm_60_atomic_functions.h", + "cuda/include/driver_functions.h", + "cuda/include/cusolver_common.h", + "cuda/include/cublas.h", + "cuda/include/curand_lognormal.h", + "cuda/include/device_atomic_functions.hpp", + "cuda/include/crt/device_runtime.h", + 
"cuda/include/crt/storage_class.h", + "cuda/include/crt/func_macro.h", + "cuda/include/crt/host_runtime.h", + "cuda/include/nppi_arithmetic_and_logical_operations.h", + "cuda/include/npps_arithmetic_and_logical_operations.h", + "cuda/include/nppi_computer_vision.h", + "cuda/include/surface_functions.hpp", + "cuda/include/surface_functions.h", + "cuda/include/curand_normal_static.h", + "cuda/include/curand.h", + "cuda/include/math_functions_dbl_ptx3.h", + "cuda/include/curand_philox4x32_x.h", + "cuda/include/nppi_threshold_and_compare_operations.h", + "cuda/include/nvml.h", + "cuda/include/npps.h", + "cuda/include/cuda_vdpau_interop.h", + "cuda/include/sm_61_intrinsics.hpp", + "cuda/include/cublas_api.h", + "cuda/include/nppi_color_conversion.h", + "cuda/include/math_functions_dbl_ptx3.hpp", + "cuda/include/nppcore.h", + "cuda/include/cudaGL.h", + "cuda/include/fatBinaryCtl.h", + "cuda/include/npps_statistics_functions.h", + "cuda/include/cudaVDPAU.h", + "cuda/include/curand_poisson.h", + "cuda/include/cusolverDn.h", + "cuda/include/cuda_profiler_api.h", + "cuda/include/sm_20_atomic_functions.h", + "cuda/include/nvfunctional", + ], + cmd = """ +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/include/math_functions.hpp" "$(@D)/cuda/include/math_functions.hpp" && cp "/usr/local/cuda-8.0/include/cufft.h" "$(@D)/cuda/include/cufft.h" && cp "/usr/local/cuda-8.0/include/nvgraph.h" "$(@D)/cuda/include/nvgraph.h" && cp "/usr/local/cuda-8.0/include/curand_normal.h" "$(@D)/cuda/include/curand_normal.h" && cp "/usr/local/cuda-8.0/include/curand_uniform.h" "$(@D)/cuda/include/curand_uniform.h" && cp "/usr/local/cuda-8.0/include/nppi_data_exchange_and_initialization.h" "$(@D)/cuda/include/nppi_data_exchange_and_initialization.h" && cp 
"/usr/local/cuda-8.0/include/cuda_gl_interop.h" "$(@D)/cuda/include/cuda_gl_interop.h" && cp "/usr/local/cuda-8.0/include/nppi_compression_functions.h" "$(@D)/cuda/include/nppi_compression_functions.h" && cp "/usr/local/cuda-8.0/include/npp.h" "$(@D)/cuda/include/npp.h" && cp "/usr/local/cuda-8.0/include/cuda.h" "$(@D)/cuda/include/cuda.h" && cp "/usr/local/cuda-8.0/include/nppi_statistics_functions.h" "$(@D)/cuda/include/nppi_statistics_functions.h" && cp "/usr/local/cuda-8.0/include/vector_functions.hpp" "$(@D)/cuda/include/vector_functions.hpp" && cp "/usr/local/cuda-8.0/include/sm_32_intrinsics.hpp" "$(@D)/cuda/include/sm_32_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/sm_32_intrinsics.h" "$(@D)/cuda/include/sm_32_intrinsics.h" && cp "/usr/local/cuda-8.0/include/curand_discrete.h" "$(@D)/cuda/include/curand_discrete.h" && cp "/usr/local/cuda-8.0/include/cuda_runtime.h" "$(@D)/cuda/include/cuda_runtime.h" && cp "/usr/local/cuda-8.0/include/cufftXt.h" "$(@D)/cuda/include/cufftXt.h" && cp "/usr/local/cuda-8.0/include/sm_61_intrinsics.h" "$(@D)/cuda/include/sm_61_intrinsics.h" && cp "/usr/local/cuda-8.0/include/texture_fetch_functions.h" "$(@D)/cuda/include/texture_fetch_functions.h" && cp "/usr/local/cuda-8.0/include/curand_mrg32k3a.h" "$(@D)/cuda/include/curand_mrg32k3a.h" && cp "/usr/local/cuda-8.0/include/host_defines.h" "$(@D)/cuda/include/host_defines.h" && cp "/usr/local/cuda-8.0/include/common_functions.h" "$(@D)/cuda/include/common_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_support_functions.h" "$(@D)/cuda/include/nppi_support_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_linear_transforms.h" "$(@D)/cuda/include/nppi_linear_transforms.h" && cp "/usr/local/cuda-8.0/include/device_double_functions.hpp" "$(@D)/cuda/include/device_double_functions.hpp" && cp "/usr/local/cuda-8.0/include/math_constants.h" "$(@D)/cuda/include/math_constants.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtSync.h" "$(@D)/cuda/include/nvToolsExtSync.h" && cp 
"/usr/local/cuda-8.0/include/npps_initialization.h" "$(@D)/cuda/include/npps_initialization.h" && cp "/usr/local/cuda-8.0/include/cusolverSp_LOWLEVEL_PREVIEW.h" "$(@D)/cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h" && cp "/usr/local/cuda-8.0/include/texture_indirect_functions.hpp" "$(@D)/cuda/include/texture_indirect_functions.hpp" && cp "/usr/local/cuda-8.0/include/cudaProfiler.h" "$(@D)/cuda/include/cudaProfiler.h" && cp "/usr/local/cuda-8.0/include/npps_filtering_functions.h" "$(@D)/cuda/include/npps_filtering_functions.h" && cp "/usr/local/cuda-8.0/include/cusparse_v2.h" "$(@D)/cuda/include/cusparse_v2.h" && cp "/usr/local/cuda-8.0/include/nppi.h" "$(@D)/cuda/include/nppi.h" && cp "/usr/local/cuda-8.0/include/surface_indirect_functions.h" "$(@D)/cuda/include/surface_indirect_functions.h" && cp "/usr/local/cuda-8.0/include/sm_30_intrinsics.h" "$(@D)/cuda/include/sm_30_intrinsics.h" && cp "/usr/local/cuda-8.0/include/device_double_functions.h" "$(@D)/cuda/include/device_double_functions.h" && cp "/usr/local/cuda-8.0/include/sm_35_intrinsics.h" "$(@D)/cuda/include/sm_35_intrinsics.h" && cp "/usr/local/cuda-8.0/include/cusolverSp.h" "$(@D)/cuda/include/cusolverSp.h" && cp "/usr/local/cuda-8.0/include/library_types.h" "$(@D)/cuda/include/library_types.h" && cp "/usr/local/cuda-8.0/include/surface_indirect_functions.hpp" "$(@D)/cuda/include/surface_indirect_functions.hpp" && cp "/usr/local/cuda-8.0/include/cudalibxt.h" "$(@D)/cuda/include/cudalibxt.h" && cp "/usr/local/cuda-8.0/include/channel_descriptor.h" "$(@D)/cuda/include/channel_descriptor.h" && cp "/usr/local/cuda-8.0/include/device_functions_decls.h" "$(@D)/cuda/include/device_functions_decls.h" && cp "/usr/local/cuda-8.0/include/curand_kernel.h" "$(@D)/cuda/include/curand_kernel.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32_host.h" "$(@D)/cuda/include/curand_mtgp32_host.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtCuda.h" "$(@D)/cuda/include/nvToolsExtCuda.h" && cp 
"/usr/local/cuda-8.0/include/nvToolsExt.h" "$(@D)/cuda/include/nvToolsExt.h" && cp "/usr/local/cuda-8.0/include/cuComplex.h" "$(@D)/cuda/include/cuComplex.h" && cp "/usr/local/cuda-8.0/include/sm_32_atomic_functions.h" "$(@D)/cuda/include/sm_32_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/texture_indirect_functions.h" "$(@D)/cuda/include/texture_indirect_functions.h" && cp "/usr/local/cuda-8.0/include/sm_32_atomic_functions.hpp" "$(@D)/cuda/include/sm_32_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/sm_20_intrinsics.hpp" "$(@D)/cuda/include/sm_20_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/device_launch_parameters.h" "$(@D)/cuda/include/device_launch_parameters.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32.h" "$(@D)/cuda/include/curand_mtgp32.h" && cp "/usr/local/cuda-8.0/include/texture_fetch_functions.hpp" "$(@D)/cuda/include/texture_fetch_functions.hpp" && cp "/usr/local/cuda-8.0/include/cuda_occupancy.h" "$(@D)/cuda/include/cuda_occupancy.h" && cp "/usr/local/cuda-8.0/include/CL/opencl.h" "$(@D)/cuda/include/CL/opencl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_platform.h" "$(@D)/cuda/include/CL/cl_platform.h" && cp "/usr/local/cuda-8.0/include/CL/cl_egl.h" "$(@D)/cuda/include/CL/cl_egl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_gl.h" "$(@D)/cuda/include/CL/cl_gl.h" && cp "/usr/local/cuda-8.0/include/CL/cl.h" "$(@D)/cuda/include/CL/cl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_gl_ext.h" "$(@D)/cuda/include/CL/cl_gl_ext.h" && cp "/usr/local/cuda-8.0/include/CL/cl_ext.h" "$(@D)/cuda/include/CL/cl_ext.h" && cp "/usr/local/cuda-8.0/include/CL/cl.hpp" "$(@D)/cuda/include/CL/cl.hpp" && cp "/usr/local/cuda-8.0/include/host_config.h" "$(@D)/cuda/include/host_config.h" && cp "/usr/local/cuda-8.0/include/cuda_surface_types.h" "$(@D)/cuda/include/cuda_surface_types.h" && cp "/usr/local/cuda-8.0/include/math_functions.h" "$(@D)/cuda/include/math_functions.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtMeta.h" 
"$(@D)/cuda/include/nvToolsExtMeta.h" && cp "/usr/local/cuda-8.0/include/sm_20_atomic_functions.hpp" "$(@D)/cuda/include/sm_20_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/device_functions.h" "$(@D)/cuda/include/device_functions.h" && cp "/usr/local/cuda-8.0/include/device_types.h" "$(@D)/cuda/include/device_types.h" && cp "/usr/local/cuda-8.0/include/npps_conversion_functions.h" "$(@D)/cuda/include/npps_conversion_functions.h" && cp "/usr/local/cuda-8.0/include/curand_precalc.h" "$(@D)/cuda/include/curand_precalc.h" && cp "/usr/local/cuda-8.0/include/cusolverRf.h" "$(@D)/cuda/include/cusolverRf.h" && cp "/usr/local/cuda-8.0/include/sm_60_atomic_functions.hpp" "$(@D)/cuda/include/sm_60_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/cuviddec.h" "$(@D)/cuda/include/cuviddec.h" && cp "/usr/local/cuda-8.0/include/curand_discrete2.h" "$(@D)/cuda/include/curand_discrete2.h" && cp "/usr/local/cuda-8.0/include/device_functions.hpp" "$(@D)/cuda/include/device_functions.hpp" && cp "/usr/local/cuda-8.0/include/thrust/transform_scan.h" "$(@D)/cuda/include/thrust/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system_error.h" "$(@D)/cuda/include/thrust/system_error.h" && cp "/usr/local/cuda-8.0/include/thrust/device_malloc.h" "$(@D)/cuda/include/thrust/device_malloc.h" && cp "/usr/local/cuda-8.0/include/thrust/partition.h" "$(@D)/cuda/include/thrust/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/unique.h" "$(@D)/cuda/include/thrust/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/device_delete.h" "$(@D)/cuda/include/thrust/device_delete.h" && cp "/usr/local/cuda-8.0/include/thrust/execution_policy.h" "$(@D)/cuda/include/thrust/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/adjacent_difference.h" "$(@D)/cuda/include/thrust/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/sequence.h" "$(@D)/cuda/include/thrust/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/merge.h" 
"$(@D)/cuda/include/thrust/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/device_new.h" "$(@D)/cuda/include/thrust/device_new.h" && cp "/usr/local/cuda-8.0/include/thrust/transform_reduce.h" "$(@D)/cuda/include/thrust/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/device_vector.h" "$(@D)/cuda/include/thrust/device_vector.h" && cp "/usr/local/cuda-8.0/include/thrust/gather.h" "$(@D)/cuda/include/thrust/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/sort.h" "$(@D)/cuda/include/thrust/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/scan.h" "$(@D)/cuda/include/thrust/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_array.h" "$(@D)/cuda/include/thrust/detail/temporary_array.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/util/align.h" "$(@D)/cuda/include/thrust/detail/util/align.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/util/blocking.h" "$(@D)/cuda/include/thrust/detail/util/blocking.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform.inl" "$(@D)/cuda/include/thrust/detail/transform.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_vector.inl" "$(@D)/cuda/include/thrust/detail/device_vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/binary_search.inl" "$(@D)/cuda/include/thrust/detail/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/overlapped_copy.h" "$(@D)/cuda/include/thrust/detail/overlapped_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/vector_base.inl" "$(@D)/cuda/include/thrust/detail/vector_base.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_reference.inl" "$(@D)/cuda/include/thrust/detail/device_reference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/actor.h" "$(@D)/cuda/include/thrust/detail/functional/actor.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/value.h" "$(@D)/cuda/include/thrust/detail/functional/value.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators.h" 
"$(@D)/cuda/include/thrust/detail/functional/operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/logical_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/logical_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/relational_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/relational_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/assignment_operator.h" "$(@D)/cuda/include/thrust/detail/functional/operators/assignment_operator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/bitwise_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/bitwise_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/operator_adaptors.h" "$(@D)/cuda/include/thrust/detail/functional/operators/operator_adaptors.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/arithmetic_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/arithmetic_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/compound_assignment_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/argument.h" "$(@D)/cuda/include/thrust/detail/functional/argument.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/placeholder.h" "$(@D)/cuda/include/thrust/detail/functional/placeholder.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/actor.inl" "$(@D)/cuda/include/thrust/detail/functional/actor.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/composite.h" "$(@D)/cuda/include/thrust/detail/functional/composite.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/static_map.h" "$(@D)/cuda/include/thrust/detail/static_map.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_nested_type.h" 
"$(@D)/cuda/include/thrust/detail/type_traits/has_nested_type.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/is_call_possible.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_call_possible.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/function_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/function_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/pointer_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/pointer_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_member_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_member_function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" "$(@D)/cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/minimum_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/minimum_type.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_trivial_assign.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_trivial_assign.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/is_metafunction_defined.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_metafunction_defined.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/iterator/is_output_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/result_of_adaptable_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference.h" "$(@D)/cuda/include/thrust/detail/reference.h" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/inner_product.inl" "$(@D)/cuda/include/thrust/detail/inner_product.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/use_default.h" "$(@D)/cuda/include/thrust/detail/use_default.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/sequence.inl" "$(@D)/cuda/include/thrust/detail/sequence.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/sort.inl" "$(@D)/cuda/include/thrust/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/equal.inl" "$(@D)/cuda/include/thrust/detail/equal.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/execution_policy.h" "$(@D)/cuda/include/thrust/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/integer_traits.h" "$(@D)/cuda/include/thrust/detail/integer_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reverse.inl" "$(@D)/cuda/include/thrust/detail/reverse.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tabulate.inl" "$(@D)/cuda/include/thrust/detail/tabulate.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/unique.inl" "$(@D)/cuda/include/thrust/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/scatter.inl" "$(@D)/cuda/include/thrust/detail/scatter.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/set_operations.inl" "$(@D)/cuda/include/thrust/detail/set_operations.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_malloc.inl" "$(@D)/cuda/include/thrust/detail/device_malloc.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy_if.inl" "$(@D)/cuda/include/thrust/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/fill.inl" "$(@D)/cuda/include/thrust/detail/fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_array.inl" "$(@D)/cuda/include/thrust/detail/temporary_array.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform_scan.inl" 
"$(@D)/cuda/include/thrust/detail/transform_scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/minmax.h" "$(@D)/cuda/include/thrust/detail/minmax.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap.inl" "$(@D)/cuda/include/thrust/detail/swap.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pointer.inl" "$(@D)/cuda/include/thrust/detail/pointer.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform_reduce.inl" "$(@D)/cuda/include/thrust/detail/transform_reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/config.h" "$(@D)/cuda/include/thrust/detail/config.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/distance.inl" "$(@D)/cuda/include/thrust/detail/distance.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pair.inl" "$(@D)/cuda/include/thrust/detail/pair.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/temporary_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/tagged_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/destroy_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/destroy_range.h" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/no_throw_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/no_throw_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/default_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/fill_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/tagged_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/allocator/malloc_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/allocator_traits.h" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/copy_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/allocator_traits.inl" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/default_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/copy_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/malloc_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/temporary_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/fill_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reduce.inl" "$(@D)/cuda/include/thrust/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_new.inl" "$(@D)/cuda/include/thrust/detail/device_new.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pointer.h" "$(@D)/cuda/include/thrust/detail/pointer.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/for_each.inl" "$(@D)/cuda/include/thrust/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/generate.inl" 
"$(@D)/cuda/include/thrust/detail/generate.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/dispatch/is_trivial_copy.h" "$(@D)/cuda/include/thrust/detail/dispatch/is_trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/detail/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple_meta_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_meta_transform.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional.inl" "$(@D)/cuda/include/thrust/detail/functional.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/remove.inl" "$(@D)/cuda/include/thrust/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_transform.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/merge.inl" "$(@D)/cuda/include/thrust/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/extrema.inl" "$(@D)/cuda/include/thrust/detail/extrema.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/trivial_sequence.h" "$(@D)/cuda/include/thrust/detail/trivial_sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/vector_base.h" "$(@D)/cuda/include/thrust/detail/vector_base.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/count.inl" "$(@D)/cuda/include/thrust/detail/count.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/function.h" "$(@D)/cuda/include/thrust/detail/function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap_ranges.inl" "$(@D)/cuda/include/thrust/detail/swap_ranges.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_delete.inl" "$(@D)/cuda/include/thrust/detail/device_delete.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/static_assert.h" "$(@D)/cuda/include/thrust/detail/static_assert.h" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/logical.inl" "$(@D)/cuda/include/thrust/detail/logical.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/seq.h" "$(@D)/cuda/include/thrust/detail/seq.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/mpl/math.h" "$(@D)/cuda/include/thrust/detail/mpl/math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/mismatch.inl" "$(@D)/cuda/include/thrust/detail/mismatch.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/internal_functional.h" "$(@D)/cuda/include/thrust/detail/internal_functional.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/get_iterator_value.h" "$(@D)/cuda/include/thrust/detail/get_iterator_value.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy.inl" "$(@D)/cuda/include/thrust/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy.h" "$(@D)/cuda/include/thrust/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/catrigf.h" "$(@D)/cuda/include/thrust/detail/complex/catrigf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cpowf.h" "$(@D)/cuda/include/thrust/detail/complex/cpowf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csqrtf.h" "$(@D)/cuda/include/thrust/detail/complex/csqrtf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ccoshf.h" "$(@D)/cuda/include/thrust/detail/complex/ccoshf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csinhf.h" "$(@D)/cuda/include/thrust/detail/complex/csinhf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/clogf.h" "$(@D)/cuda/include/thrust/detail/complex/clogf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ccosh.h" "$(@D)/cuda/include/thrust/detail/complex/ccosh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/arithmetic.h" "$(@D)/cuda/include/thrust/detail/complex/arithmetic.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csqrt.h" "$(@D)/cuda/include/thrust/detail/complex/csqrt.h" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/complex/cpow.h" "$(@D)/cuda/include/thrust/detail/complex/cpow.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/complex.inl" "$(@D)/cuda/include/thrust/detail/complex/complex.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/math_private.h" "$(@D)/cuda/include/thrust/detail/complex/math_private.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/c99math.h" "$(@D)/cuda/include/thrust/detail/complex/c99math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cproj.h" "$(@D)/cuda/include/thrust/detail/complex/cproj.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/catrig.h" "$(@D)/cuda/include/thrust/detail/complex/catrig.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ctanhf.h" "$(@D)/cuda/include/thrust/detail/complex/ctanhf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cexpf.h" "$(@D)/cuda/include/thrust/detail/complex/cexpf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csinh.h" "$(@D)/cuda/include/thrust/detail/complex/csinh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/stream.h" "$(@D)/cuda/include/thrust/detail/complex/stream.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ctanh.h" "$(@D)/cuda/include/thrust/detail/complex/ctanh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cexp.h" "$(@D)/cuda/include/thrust/detail/complex/cexp.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/clog.h" "$(@D)/cuda/include/thrust/detail/complex/clog.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/range/head_flags.h" "$(@D)/cuda/include/thrust/detail/range/head_flags.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/range/tail_flags.h" "$(@D)/cuda/include/thrust/detail/range/tail_flags.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/execute_with_allocator.h" "$(@D)/cuda/include/thrust/detail/execute_with_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/integer_math.h" 
"$(@D)/cuda/include/thrust/detail/integer_math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap.h" "$(@D)/cuda/include/thrust/detail/swap.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/scan.inl" "$(@D)/cuda/include/thrust/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/gather.inl" "$(@D)/cuda/include/thrust/detail/gather.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference_forward_declaration.h" "$(@D)/cuda/include/thrust/detail/reference_forward_declaration.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/numeric_traits.h" "$(@D)/cuda/include/thrust/detail/numeric_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference.inl" "$(@D)/cuda/include/thrust/detail/reference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/cstdint.h" "$(@D)/cuda/include/thrust/detail/cstdint.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_free.inl" "$(@D)/cuda/include/thrust/detail/device_free.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy_if.h" "$(@D)/cuda/include/thrust/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/partition.inl" "$(@D)/cuda/include/thrust/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/find.inl" "$(@D)/cuda/include/thrust/detail/find.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/forceinline.h" "$(@D)/cuda/include/thrust/detail/config/forceinline.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/debug.h" "$(@D)/cuda/include/thrust/detail/config/debug.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/config.h" "$(@D)/cuda/include/thrust/detail/config/config.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/host_device.h" "$(@D)/cuda/include/thrust/detail/config/host_device.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/host_system.h" 
"$(@D)/cuda/include/thrust/detail/config/host_system.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/compiler.h" "$(@D)/cuda/include/thrust/detail/config/compiler.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/device_system.h" "$(@D)/cuda/include/thrust/detail/config/device_system.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/compiler_fence.h" "$(@D)/cuda/include/thrust/detail/config/compiler_fence.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/exec_check_disable.h" "$(@D)/cuda/include/thrust/detail/config/exec_check_disable.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/simple_defines.h" "$(@D)/cuda/include/thrust/detail/config/simple_defines.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/global_workarounds.h" "$(@D)/cuda/include/thrust/detail/config/global_workarounds.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/replace.inl" "$(@D)/cuda/include/thrust/detail/replace.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_ptr.inl" "$(@D)/cuda/include/thrust/detail/device_ptr.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple.inl" "$(@D)/cuda/include/thrust/detail/tuple.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/host_vector.inl" "$(@D)/cuda/include/thrust/detail/host_vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/raw_pointer_cast.h" "$(@D)/cuda/include/thrust/detail/raw_pointer_cast.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/advance.inl" "$(@D)/cuda/include/thrust/detail/advance.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/contiguous_storage.h" "$(@D)/cuda/include/thrust/detail/contiguous_storage.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/raw_reference_cast.h" "$(@D)/cuda/include/thrust/detail/raw_reference_cast.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/contiguous_storage.inl" 
"$(@D)/cuda/include/thrust/detail/contiguous_storage.inl" && cp "/usr/local/cuda-8.0/include/thrust/reverse.h" "$(@D)/cuda/include/thrust/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/device_malloc_allocator.h" "$(@D)/cuda/include/thrust/device_malloc_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/scatter.h" "$(@D)/cuda/include/thrust/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/pair.h" "$(@D)/cuda/include/thrust/pair.h" && cp "/usr/local/cuda-8.0/include/thrust/advance.h" "$(@D)/cuda/include/thrust/advance.h" && cp "/usr/local/cuda-8.0/include/thrust/find.h" "$(@D)/cuda/include/thrust/find.h" && cp "/usr/local/cuda-8.0/include/thrust/device_ptr.h" "$(@D)/cuda/include/thrust/device_ptr.h" && cp "/usr/local/cuda-8.0/include/thrust/generate.h" "$(@D)/cuda/include/thrust/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/uninitialized_fill.h" "$(@D)/cuda/include/thrust/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/system_error.h" "$(@D)/cuda/include/thrust/system/system_error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/bad_alloc.h" "$(@D)/cuda/include/thrust/system/detail/bad_alloc.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/partition.h" "$(@D)/cuda/include/thrust/system/detail/adl/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/unique.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/adl/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/sequence.h" "$(@D)/cuda/include/thrust/system/detail/adl/sequence.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/adl/merge.h" "$(@D)/cuda/include/thrust/system/detail/adl/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/gather.h" "$(@D)/cuda/include/thrust/system/detail/adl/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/sort.h" "$(@D)/cuda/include/thrust/system/detail/adl/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/adl/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reverse.h" "$(@D)/cuda/include/thrust/system/detail/adl/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scatter.h" "$(@D)/cuda/include/thrust/system/detail/adl/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/find.h" "$(@D)/cuda/include/thrust/system/detail/adl/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/generate.h" "$(@D)/cuda/include/thrust/system/detail/adl/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/remove.h" "$(@D)/cuda/include/thrust/system/detail/adl/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/adl/tabulate.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/adl/for_each.h" "$(@D)/cuda/include/thrust/system/detail/adl/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/equal.h" "$(@D)/cuda/include/thrust/system/detail/adl/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/adl/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/adl/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/adl/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/adl/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/extrema.h" "$(@D)/cuda/include/thrust/system/detail/adl/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/count.h" "$(@D)/cuda/include/thrust/system/detail/adl/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/replace.h" "$(@D)/cuda/include/thrust/system/detail/adl/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/get_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/adl/inner_product.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/adl/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/logical.h" "$(@D)/cuda/include/thrust/system/detail/adl/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/adl/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/adl/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/errno.h" "$(@D)/cuda/include/thrust/system/detail/errno.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_category.inl" "$(@D)/cuda/include/thrust/system/detail/error_category.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/sequential/partition.h" "$(@D)/cuda/include/thrust/system/detail/sequential/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/unique.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/execution_policy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/sequential/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sequence.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/merge.h" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/gather.h" "$(@D)/cuda/include/thrust/system/detail/sequential/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy_backward.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_backward.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/sequential/temporary_buffer.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reverse.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scatter.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/find.h" "$(@D)/cuda/include/thrust/system/detail/sequential/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/merge.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/generate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/general_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/general_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/insertion_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/insertion_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/remove.h" "$(@D)/cuda/include/thrust/system/detail/sequential/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/for_each.h" 
"$(@D)/cuda/include/thrust/system/detail/sequential/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/equal.h" "$(@D)/cuda/include/thrust/system/detail/sequential/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/sequential/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/sequential/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/sequential/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/sequential/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/extrema.h" "$(@D)/cuda/include/thrust/system/detail/sequential/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/count.h" "$(@D)/cuda/include/thrust/system/detail/sequential/count.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/sequential/trivial_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/replace.h" "$(@D)/cuda/include/thrust/system/detail/sequential/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/get_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/sequential/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/logical.h" "$(@D)/cuda/include/thrust/system/detail/sequential/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/sequential/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/sequential/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_condition.inl" "$(@D)/cuda/include/thrust/system/detail/error_condition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/internal/decompose.h" "$(@D)/cuda/include/thrust/system/detail/internal/decompose.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_code.inl" "$(@D)/cuda/include/thrust/system/detail/error_code.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_scan.h" 
"$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/memory.inl" "$(@D)/cuda/include/thrust/system/detail/generic/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/inner_product.inl" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/select_system.h" "$(@D)/cuda/include/thrust/system/detail/generic/select_system.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sequence.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sort.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/equal.inl" "$(@D)/cuda/include/thrust/system/detail/generic/equal.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/partition.h" "$(@D)/cuda/include/thrust/system/detail/generic/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/generic/tag.h" "$(@D)/cuda/include/thrust/system/detail/generic/tag.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sequence.h" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/type_traits.h" "$(@D)/cuda/include/thrust/system/detail/generic/type_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/merge.h" "$(@D)/cuda/include/thrust/system/detail/generic/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reverse.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tabulate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scatter.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/set_operations.inl" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy_if.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/gather.h" "$(@D)/cuda/include/thrust/system/detail/generic/gather.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sort.h" "$(@D)/cuda/include/thrust/system/detail/generic/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/distance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/distance.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reverse.h" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/temporary_buffer.inl" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scatter.h" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/generate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/generate.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/remove.inl" 
"$(@D)/cuda/include/thrust/system/detail/generic/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/advance.h" "$(@D)/cuda/include/thrust/system/detail/generic/advance.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/find.h" "$(@D)/cuda/include/thrust/system/detail/generic/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/merge.inl" "$(@D)/cuda/include/thrust/system/detail/generic/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scalar/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scalar/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/extrema.inl" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/generate.h" "$(@D)/cuda/include/thrust/system/detail/generic/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/count.inl" "$(@D)/cuda/include/thrust/system/detail/generic/count.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/remove.h" "$(@D)/cuda/include/thrust/system/detail/generic/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/for_each.h" "$(@D)/cuda/include/thrust/system/detail/generic/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/distance.h" 
"$(@D)/cuda/include/thrust/system/detail/generic/distance.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/swap_ranges.inl" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/equal.h" "$(@D)/cuda/include/thrust/system/detail/generic/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/mismatch.inl" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan.inl" 
"$(@D)/cuda/include/thrust/system/detail/generic/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/gather.inl" "$(@D)/cuda/include/thrust/system/detail/generic/gather.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/extrema.h" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/count.h" "$(@D)/cuda/include/thrust/system/detail/generic/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/replace.h" "$(@D)/cuda/include/thrust/system/detail/generic/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/logical.h" "$(@D)/cuda/include/thrust/system/detail/generic/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/partition.inl" "$(@D)/cuda/include/thrust/system/detail/generic/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/memory.h" "$(@D)/cuda/include/thrust/system/detail/generic/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/find.inl" "$(@D)/cuda/include/thrust/system/detail/generic/find.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/replace.inl" "$(@D)/cuda/include/thrust/system/detail/generic/replace.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/advance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/advance.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/system_error.inl" "$(@D)/cuda/include/thrust/system/detail/system_error.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/vector.h" "$(@D)/cuda/include/thrust/system/omp/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/omp/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sort.inl" "$(@D)/cuda/include/thrust/system/omp/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/partition.h" "$(@D)/cuda/include/thrust/system/omp/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/omp/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/omp/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/merge.h" "$(@D)/cuda/include/thrust/system/omp/detail/merge.h" 
&& cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/gather.h" "$(@D)/cuda/include/thrust/system/omp/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sort.h" "$(@D)/cuda/include/thrust/system/omp/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/omp/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/omp/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/omp/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/for_each.inl" 
"$(@D)/cuda/include/thrust/system/omp/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/remove.inl" "$(@D)/cuda/include/thrust/system/omp/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/omp/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/find.h" "$(@D)/cuda/include/thrust/system/omp/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/generate.h" "$(@D)/cuda/include/thrust/system/omp/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/remove.h" "$(@D)/cuda/include/thrust/system/omp/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/omp/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/equal.h" "$(@D)/cuda/include/thrust/system/omp/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/swap_ranges.h" 
"$(@D)/cuda/include/thrust/system/omp/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/omp/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/omp/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/omp/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/omp/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/count.h" "$(@D)/cuda/include/thrust/system/omp/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/replace.h" "$(@D)/cuda/include/thrust/system/omp/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/omp/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/logical.h" "$(@D)/cuda/include/thrust/system/omp/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/partition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/omp/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/par.h" "$(@D)/cuda/include/thrust/system/omp/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_intervals.h" 
"$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/omp/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/memory.h" "$(@D)/cuda/include/thrust/system/omp/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/vector.h" "$(@D)/cuda/include/thrust/system/tbb/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/memory.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sort.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/partition.h" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/tbb/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique_by_key.inl" 
"$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sequence.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/merge.h" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/gather.h" "$(@D)/cuda/include/thrust/system/tbb/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sort.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/tbb/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reverse.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scatter.h" 
"$(@D)/cuda/include/thrust/system/tbb/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/remove.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/vector.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/find.h" "$(@D)/cuda/include/thrust/system/tbb/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/merge.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/generate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/remove.h" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/for_each.h" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/equal.h" "$(@D)/cuda/include/thrust/system/tbb/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/tbb/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/tbb/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/tbb/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/tbb/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/tbb/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/extrema.h" "$(@D)/cuda/include/thrust/system/tbb/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/count.h" "$(@D)/cuda/include/thrust/system/tbb/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/replace.h" "$(@D)/cuda/include/thrust/system/tbb/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/get_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/tbb/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/logical.h" "$(@D)/cuda/include/thrust/system/tbb/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/partition.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/tbb/detail/iter_swap.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/tbb/detail/par.h" "$(@D)/cuda/include/thrust/system/tbb/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/tbb/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/memory.h" "$(@D)/cuda/include/thrust/system/tbb/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/error_code.h" "$(@D)/cuda/include/thrust/system/error_code.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/vector.h" "$(@D)/cuda/include/thrust/system/cpp/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/partition.h" "$(@D)/cuda/include/thrust/system/cpp/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/unique.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/execution_policy.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cpp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cpp/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/merge.h" "$(@D)/cuda/include/thrust/system/cpp/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/gather.h" "$(@D)/cuda/include/thrust/system/cpp/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/sort.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cpp/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/find.h" "$(@D)/cuda/include/thrust/system/cpp/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/generate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/generate.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cpp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/remove.h" "$(@D)/cuda/include/thrust/system/cpp/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cpp/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/equal.h" "$(@D)/cuda/include/thrust/system/cpp/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cpp/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cpp/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cpp/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cpp/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cpp/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/count.h" "$(@D)/cuda/include/thrust/system/cpp/detail/count.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cpp/detail/replace.h" "$(@D)/cuda/include/thrust/system/cpp/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cpp/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/logical.h" "$(@D)/cuda/include/thrust/system/cpp/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cpp/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/par.h" "$(@D)/cuda/include/thrust/system/cpp/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cpp/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/memory.h" "$(@D)/cuda/include/thrust/system/cpp/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/vector.h" "$(@D)/cuda/include/thrust/system/cuda/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/error.h" "$(@D)/cuda/include/thrust/system/cuda/error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_device_to_device.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_device_to_device.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform_scan.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_allocator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_device.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_device.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_macro.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_namespace.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_type.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_type.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/host/spinlock.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/host/spinlock.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_ptx.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_debug.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/cub.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/cub.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_shift.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_shift.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_arch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_cross_system.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_cross_system.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk.h" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/partition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/unique.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cuda_launch_config.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cuda_launch_config.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cub.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_symmetric_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_symmetric_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/error.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/error.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/gather.h" "$(@D)/cuda/include/thrust/system/cuda/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/synchronize.h" "$(@D)/cuda/include/thrust/system/cuda/detail/synchronize.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/temporary_indirect_permutation.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/temporary_indirect_permutation.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/extern_shared_ptr.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extern_shared_ptr.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/set_operation.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/set_operation.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/balanced_path.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/balanced_path.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/set_operation.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/set_operation.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_closure.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_closure.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/alignment.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/alignment.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl" 
&& cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_sort_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_calculator.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_calculator.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_closure.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_closure.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/uninitialized.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/uninitialized.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_calculator.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_calculator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_sort_each.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/default_decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan_by_key.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/default_decomposition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/throw_on_error.h" "$(@D)/cuda/include/thrust/system/cuda/detail/throw_on_error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/find.h" "$(@D)/cuda/include/thrust/system/cuda/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/terminate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/terminate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/merge.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/trivial_copy.inl" 
"$(@D)/cuda/include/thrust/system/cuda/detail/trivial_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/generate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/execute_on_stream.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execute_on_stream.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/remove.h" "$(@D)/cuda/include/thrust/system/cuda/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/decomposition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/equal.h" "$(@D)/cuda/include/thrust/system/cuda/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/runtime_introspection.h" "$(@D)/cuda/include/thrust/system/cuda/detail/runtime_introspection.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cuda/detail/swap_ranges.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cuda/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/runtime_introspection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/runtime_introspection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cuda/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/synchronize.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/synchronize.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_union.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_union.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_intersection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_intersection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/count.h" "$(@D)/cuda/include/thrust/system/cuda/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/trivial_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_device_to_device.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_device_to_device.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/replace.h" "$(@D)/cuda/include/thrust/system/cuda/detail/replace.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/malloc.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/malloc.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/config.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/config.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/closure.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/closure.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/async.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/async.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp" 
"$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp" 
"$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/bulk.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/bulk.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/execution_policy.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/execution_policy.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/uninitialized.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/uninitialized.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/async.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/async.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/future.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/future.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/guarded_driver_types.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_driver_types.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/get_value.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cuda/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/logical.h" "$(@D)/cuda/include/thrust/system/cuda/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cuda/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/inclusive_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/inclusive_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merge.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merging_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merging_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/exclusive_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/exclusive_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/odd_even_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/odd_even_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/par.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_cross_system.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cuda/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/experimental/pinned_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/experimental/pinned_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/memory.h" "$(@D)/cuda/include/thrust/system/cuda/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/remove.h" "$(@D)/cuda/include/thrust/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/tabulate.h" "$(@D)/cuda/include/thrust/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/for_each.h" "$(@D)/cuda/include/thrust/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/distance.h" "$(@D)/cuda/include/thrust/distance.h" && cp "/usr/local/cuda-8.0/include/thrust/reduce.h" "$(@D)/cuda/include/thrust/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/equal.h" "$(@D)/cuda/include/thrust/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/complex.h" "$(@D)/cuda/include/thrust/complex.h" && cp "/usr/local/cuda-8.0/include/thrust/device_allocator.h" "$(@D)/cuda/include/thrust/device_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/copy.h" "$(@D)/cuda/include/thrust/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/uninitialized_copy.h" "$(@D)/cuda/include/thrust/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/device_reference.h" 
"$(@D)/cuda/include/thrust/device_reference.h" && cp "/usr/local/cuda-8.0/include/thrust/binary_search.h" "$(@D)/cuda/include/thrust/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/set_operations.h" "$(@D)/cuda/include/thrust/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/swap.h" "$(@D)/cuda/include/thrust/swap.h" && cp "/usr/local/cuda-8.0/include/thrust/mismatch.h" "$(@D)/cuda/include/thrust/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/extrema.h" "$(@D)/cuda/include/thrust/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/count.h" "$(@D)/cuda/include/thrust/count.h" && cp "/usr/local/cuda-8.0/include/thrust/device_free.h" "$(@D)/cuda/include/thrust/device_free.h" && cp "/usr/local/cuda-8.0/include/thrust/random/discard_block_engine.h" "$(@D)/cuda/include/thrust/random/discard_block_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/normal_distribution.h" "$(@D)/cuda/include/thrust/random/normal_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/subtract_with_carry_engine.inl" "$(@D)/cuda/include/thrust/random/detail/subtract_with_carry_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/xor_combine_engine_max.h" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine_max.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_congruential_engine_discard.h" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine_discard.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/uniform_int_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_int_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/discard_block_engine.inl" "$(@D)/cuda/include/thrust/random/detail/discard_block_engine.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/random/detail/uniform_real_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_real_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/random_core_access.h" "$(@D)/cuda/include/thrust/random/detail/random_core_access.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/mod.h" "$(@D)/cuda/include/thrust/random/detail/mod.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_feedback_shift_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_congruential_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/xor_combine_engine.inl" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/normal_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/normal_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/normal_distribution_base.h" "$(@D)/cuda/include/thrust/random/detail/normal_distribution_base.h" && cp "/usr/local/cuda-8.0/include/thrust/random/uniform_int_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_int_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/random/linear_feedback_shift_engine.h" "$(@D)/cuda/include/thrust/random/linear_feedback_shift_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/xor_combine_engine.h" "$(@D)/cuda/include/thrust/random/xor_combine_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/subtract_with_carry_engine.h" "$(@D)/cuda/include/thrust/random/subtract_with_carry_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/linear_congruential_engine.h" "$(@D)/cuda/include/thrust/random/linear_congruential_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/uniform_real_distribution.h" 
"$(@D)/cuda/include/thrust/random/uniform_real_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/functional.h" "$(@D)/cuda/include/thrust/functional.h" && cp "/usr/local/cuda-8.0/include/thrust/replace.h" "$(@D)/cuda/include/thrust/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/device_new_allocator.h" "$(@D)/cuda/include/thrust/device_new_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/host_vector.h" "$(@D)/cuda/include/thrust/host_vector.h" && cp "/usr/local/cuda-8.0/include/thrust/version.h" "$(@D)/cuda/include/thrust/version.h" && cp "/usr/local/cuda-8.0/include/thrust/inner_product.h" "$(@D)/cuda/include/thrust/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_traits.h" "$(@D)/cuda/include/thrust/iterator/iterator_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/discard_iterator.h" "$(@D)/cuda/include/thrust/iterator/discard_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/retag.h" "$(@D)/cuda/include/thrust/iterator/retag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/permutation_iterator.h" "$(@D)/cuda/include/thrust/iterator/permutation_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/transform_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/reverse_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/zip_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/counting_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/counting_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/distance_from_result.h" "$(@D)/cuda/include/thrust/iterator/detail/distance_from_result.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/host_system_tag.h" 
"$(@D)/cuda/include/thrust/iterator/detail/host_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_traversal_tags.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traversal_tags.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/retag.h" "$(@D)/cuda/include/thrust/iterator/detail/retag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/tagged_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/tagged_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_traits.inl" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traits.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/minimum_category.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/discard_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/discard_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_to_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/zip_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/normal_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/normal_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/join_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/join_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/device_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/device_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/universal_categories.h" "$(@D)/cuda/include/thrust/iterator/detail/universal_categories.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/reverse_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator_base.h" && cp 
"/usr/local/cuda-8.0/include/thrust/iterator/detail/minimum_system.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_system.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/tuple_of_iterator_references.h" "$(@D)/cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/is_iterator_category.h" "$(@D)/cuda/include/thrust/iterator/detail/is_iterator_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/permutation_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/permutation_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/any_assign.h" "$(@D)/cuda/include/thrust/iterator/detail/any_assign.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/any_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/any_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/is_trivial_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/is_trivial_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_to_system.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_system.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_adaptor_base.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_adaptor_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/constant_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/constant_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/transform_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_facade_category.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_facade_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" && cp 
"/usr/local/cuda-8.0/include/thrust/iterator/constant_iterator.h" "$(@D)/cuda/include/thrust/iterator/constant_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/counting_iterator.h" "$(@D)/cuda/include/thrust/iterator/counting_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_adaptor.h" "$(@D)/cuda/include/thrust/iterator/iterator_adaptor.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_facade.h" "$(@D)/cuda/include/thrust/iterator/iterator_facade.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_categories.h" "$(@D)/cuda/include/thrust/iterator/iterator_categories.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/reverse_iterator.h" "$(@D)/cuda/include/thrust/iterator/reverse_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/zip_iterator.h" "$(@D)/cuda/include/thrust/iterator/zip_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/logical.h" "$(@D)/cuda/include/thrust/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/tuple.h" "$(@D)/cuda/include/thrust/tuple.h" && cp "/usr/local/cuda-8.0/include/thrust/memory.h" "$(@D)/cuda/include/thrust/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/random.h" "$(@D)/cuda/include/thrust/random.h" && cp "/usr/local/cuda-8.0/include/thrust/fill.h" "$(@D)/cuda/include/thrust/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/transform.h" "$(@D)/cuda/include/thrust/transform.h" && cp "/usr/local/cuda-8.0/include/texture_types.h" "$(@D)/cuda/include/texture_types.h" && cp "/usr/local/cuda-8.0/include/nppversion.h" "$(@D)/cuda/include/nppversion.h" && cp "/usr/local/cuda-8.0/include/cuda_texture_types.h" "$(@D)/cuda/include/cuda_texture_types.h" && cp "/usr/local/cuda-8.0/include/fatbinary.h" "$(@D)/cuda/include/fatbinary.h" && cp "/usr/local/cuda-8.0/include/cublasXt.h" "$(@D)/cuda/include/cublasXt.h" && cp "/usr/local/cuda-8.0/include/cuda_fp16.h" "$(@D)/cuda/include/cuda_fp16.h" && cp "/usr/local/cuda-8.0/include/vector_functions.h" 
"$(@D)/cuda/include/vector_functions.h" && cp "/usr/local/cuda-8.0/include/cusparse.h" "$(@D)/cuda/include/cusparse.h" && cp "/usr/local/cuda-8.0/include/nppi_filtering_functions.h" "$(@D)/cuda/include/nppi_filtering_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_morphological_operations.h" "$(@D)/cuda/include/nppi_morphological_operations.h" && cp "/usr/local/cuda-8.0/include/sobol_direction_vectors.h" "$(@D)/cuda/include/sobol_direction_vectors.h" && cp "/usr/local/cuda-8.0/include/nvblas.h" "$(@D)/cuda/include/nvblas.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32dc_p_11213.h" "$(@D)/cuda/include/curand_mtgp32dc_p_11213.h" && cp "/usr/local/cuda-8.0/include/nvcuvid.h" "$(@D)/cuda/include/nvcuvid.h" && cp "/usr/local/cuda-8.0/include/cuda_runtime_api.h" "$(@D)/cuda/include/cuda_runtime_api.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32_kernel.h" "$(@D)/cuda/include/curand_mtgp32_kernel.h" && cp "/usr/local/cuda-8.0/include/cublas_v2.h" "$(@D)/cuda/include/cublas_v2.h" && cp "/usr/local/cuda-8.0/include/builtin_types.h" "$(@D)/cuda/include/builtin_types.h" && cp "/usr/local/cuda-8.0/include/nppi_geometry_transforms.h" "$(@D)/cuda/include/nppi_geometry_transforms.h" && cp "/usr/local/cuda-8.0/include/npps_support_functions.h" "$(@D)/cuda/include/npps_support_functions.h" && cp "/usr/local/cuda-8.0/include/cufftw.h" "$(@D)/cuda/include/cufftw.h" && cp "/usr/local/cuda-8.0/include/cuda_device_runtime_api.h" "$(@D)/cuda/include/cuda_device_runtime_api.h" && cp "/usr/local/cuda-8.0/include/sm_30_intrinsics.hpp" "$(@D)/cuda/include/sm_30_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/vector_types.h" "$(@D)/cuda/include/vector_types.h" && cp "/usr/local/cuda-8.0/include/sm_35_atomic_functions.h" "$(@D)/cuda/include/sm_35_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/sm_20_intrinsics.h" "$(@D)/cuda/include/sm_20_intrinsics.h" && cp "/usr/local/cuda-8.0/include/driver_types.h" "$(@D)/cuda/include/driver_types.h" && cp 
"/usr/local/cuda-8.0/include/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvToolsExtCudaRt.h" && cp "/usr/local/cuda-8.0/include/curand_globals.h" "$(@D)/cuda/include/curand_globals.h" && cp "/usr/local/cuda-8.0/include/device_atomic_functions.h" "$(@D)/cuda/include/device_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/surface_types.h" "$(@D)/cuda/include/surface_types.h" && cp "/usr/local/cuda-8.0/include/nvrtc.h" "$(@D)/cuda/include/nvrtc.h" && cp "/usr/local/cuda-8.0/include/nppdefs.h" "$(@D)/cuda/include/nppdefs.h" && cp "/usr/local/cuda-8.0/include/sm_60_atomic_functions.h" "$(@D)/cuda/include/sm_60_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/driver_functions.h" "$(@D)/cuda/include/driver_functions.h" && cp "/usr/local/cuda-8.0/include/cusolver_common.h" "$(@D)/cuda/include/cusolver_common.h" && cp "/usr/local/cuda-8.0/include/cublas.h" "$(@D)/cuda/include/cublas.h" && cp "/usr/local/cuda-8.0/include/curand_lognormal.h" "$(@D)/cuda/include/curand_lognormal.h" && cp "/usr/local/cuda-8.0/include/device_atomic_functions.hpp" "$(@D)/cuda/include/device_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/crt/device_runtime.h" "$(@D)/cuda/include/crt/device_runtime.h" && cp "/usr/local/cuda-8.0/include/crt/storage_class.h" "$(@D)/cuda/include/crt/storage_class.h" && cp "/usr/local/cuda-8.0/include/crt/func_macro.h" "$(@D)/cuda/include/crt/func_macro.h" && cp "/usr/local/cuda-8.0/include/crt/host_runtime.h" "$(@D)/cuda/include/crt/host_runtime.h" && cp "/usr/local/cuda-8.0/include/nppi_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/nppi_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-8.0/include/npps_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/npps_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-8.0/include/nppi_computer_vision.h" "$(@D)/cuda/include/nppi_computer_vision.h" && cp "/usr/local/cuda-8.0/include/surface_functions.hpp" "$(@D)/cuda/include/surface_functions.hpp" && cp 
"/usr/local/cuda-8.0/include/surface_functions.h" "$(@D)/cuda/include/surface_functions.h" && cp "/usr/local/cuda-8.0/include/curand_normal_static.h" "$(@D)/cuda/include/curand_normal_static.h" && cp "/usr/local/cuda-8.0/include/curand.h" "$(@D)/cuda/include/curand.h" && cp "/usr/local/cuda-8.0/include/math_functions_dbl_ptx3.h" "$(@D)/cuda/include/math_functions_dbl_ptx3.h" && cp "/usr/local/cuda-8.0/include/curand_philox4x32_x.h" "$(@D)/cuda/include/curand_philox4x32_x.h" && cp "/usr/local/cuda-8.0/include/nppi_threshold_and_compare_operations.h" "$(@D)/cuda/include/nppi_threshold_and_compare_operations.h" && cp "/usr/local/cuda-8.0/include/nvml.h" "$(@D)/cuda/include/nvml.h" && cp "/usr/local/cuda-8.0/include/npps.h" "$(@D)/cuda/include/npps.h" && cp "/usr/local/cuda-8.0/include/cuda_vdpau_interop.h" "$(@D)/cuda/include/cuda_vdpau_interop.h" && cp "/usr/local/cuda-8.0/include/sm_61_intrinsics.hpp" "$(@D)/cuda/include/sm_61_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/cublas_api.h" "$(@D)/cuda/include/cublas_api.h" && cp "/usr/local/cuda-8.0/include/nppi_color_conversion.h" "$(@D)/cuda/include/nppi_color_conversion.h" && cp "/usr/local/cuda-8.0/include/math_functions_dbl_ptx3.hpp" "$(@D)/cuda/include/math_functions_dbl_ptx3.hpp" && cp "/usr/local/cuda-8.0/include/nppcore.h" "$(@D)/cuda/include/nppcore.h" && cp "/usr/local/cuda-8.0/include/cudaGL.h" "$(@D)/cuda/include/cudaGL.h" && cp "/usr/local/cuda-8.0/include/fatBinaryCtl.h" "$(@D)/cuda/include/fatBinaryCtl.h" && cp "/usr/local/cuda-8.0/include/npps_statistics_functions.h" "$(@D)/cuda/include/npps_statistics_functions.h" && cp "/usr/local/cuda-8.0/include/cudaVDPAU.h" "$(@D)/cuda/include/cudaVDPAU.h" && cp "/usr/local/cuda-8.0/include/curand_poisson.h" "$(@D)/cuda/include/curand_poisson.h" && cp "/usr/local/cuda-8.0/include/cusolverDn.h" "$(@D)/cuda/include/cusolverDn.h" && cp "/usr/local/cuda-8.0/include/cuda_profiler_api.h" "$(@D)/cuda/include/cuda_profiler_api.h" && cp 
"/usr/local/cuda-8.0/include/sm_20_atomic_functions.h" "$(@D)/cuda/include/sm_20_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/nvfunctional" "$(@D)/cuda/include/nvfunctional" + """, +) + +genrule( + name = "cuda-nvvm", + outs = [ + "cuda/nvvm/bin/cicc", + "cuda/nvvm/libdevice/libdevice.compute_50.10.bc", + "cuda/nvvm/libdevice/libdevice.compute_30.10.bc", + "cuda/nvvm/libdevice/libdevice.compute_20.10.bc", + "cuda/nvvm/libdevice/libdevice.compute_35.10.bc", + "cuda/nvvm/lib64/libnvvm.so.3", + "cuda/nvvm/lib64/libnvvm.so", + "cuda/nvvm/lib64/libnvvm.so.3.1.0", + "cuda/nvvm/include/nvvm.h", + "cuda/nvvm/libnvvm-samples/ptxgen/README.txt", + "cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c", + "cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/build.bat", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp", + "cuda/nvvm/libnvvm-samples/README.txt", + "cuda/nvvm/libnvvm-samples/simple/simple.c", + "cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll", + "cuda/nvvm/libnvvm-samples/simple/README.txt", + "cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll", + "cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h", + "cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h", + "cuda/nvvm/libnvvm-samples/build.sh", + "cuda/nvvm/libnvvm-samples/CMakeLists.txt", + ], + cmd = """ +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/nvvm/bin/cicc" "$(@D)/cuda/nvvm/bin/cicc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_50.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_50.10.bc" && cp 
"/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_30.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_30.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_20.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_20.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_35.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_35.10.bc" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so.3" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so" "$(@D)/cuda/nvvm/lib64/libnvvm.so" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so.3.1.0" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3.1.0" && cp "/usr/local/cuda-8.0/nvvm/include/nvvm.h" "$(@D)/cuda/nvvm/include/nvvm.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/ptxgen.c" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/build.bat" "$(@D)/cuda/nvvm/libnvvm-samples/build.bat" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple.c" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple.c" && cp 
"/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple-gpu.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple-gpu64.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/common/include/DDSWriter.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/build.sh" "$(@D)/cuda/nvvm/libnvvm-samples/build.sh" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/CMakeLists.txt" + """, +) + +genrule( + name = "cuda-extras", + outs = [ + "cuda/extras/CUPTI/include/cupti_result.h", + "cuda/extras/CUPTI/include/cupti_events.h", + "cuda/extras/CUPTI/include/openacc/cupti_openacc.h", + "cuda/extras/CUPTI/include/cupti_version.h", + "cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h", + "cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h", + "cuda/extras/CUPTI/include/cupti_activity.h", + "cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h", + "cuda/extras/CUPTI/include/generated_cuda_meta.h", + "cuda/extras/CUPTI/include/cupti_nvtx_cbid.h", + "cuda/extras/CUPTI/include/cuda_stdint.h", + "cuda/extras/CUPTI/include/generated_cudaGL_meta.h", + "cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h", + "cuda/extras/CUPTI/include/cupti_metrics.h", + "cuda/extras/CUPTI/include/cupti_callbacks.h", + "cuda/extras/CUPTI/include/cupti_runtime_cbid.h", + "cuda/extras/CUPTI/include/cupti.h", + "cuda/extras/CUPTI/include/GL/glut.h", + 
"cuda/extras/CUPTI/include/GL/glu.h", + "cuda/extras/CUPTI/include/GL/glxext.h", + "cuda/extras/CUPTI/include/GL/wglext.h", + "cuda/extras/CUPTI/include/GL/glx.h", + "cuda/extras/CUPTI/include/GL/glext.h", + "cuda/extras/CUPTI/include/GL/wglew.h", + "cuda/extras/CUPTI/include/GL/gl.h", + "cuda/extras/CUPTI/include/GL/glew.h", + "cuda/extras/CUPTI/include/cupti_driver_cbid.h", + "cuda/extras/CUPTI/include/generated_nvtx_meta.h", + ], + cmd = """ +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_result.h" "$(@D)/cuda/extras/CUPTI/include/cupti_result.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_events.h" "$(@D)/cuda/extras/CUPTI/include/cupti_events.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/openacc/cupti_openacc.h" "$(@D)/cuda/extras/CUPTI/include/openacc/cupti_openacc.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_version.h" "$(@D)/cuda/extras/CUPTI/include/cupti_version.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cudaVDPAU_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_activity.h" "$(@D)/cuda/extras/CUPTI/include/cupti_activity.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_nvtx_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_nvtx_cbid.h" && cp 
"/usr/local/cuda-8.0/extras/CUPTI/include/cuda_stdint.h" "$(@D)/cuda/extras/CUPTI/include/cuda_stdint.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cudaGL_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaGL_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_metrics.h" "$(@D)/cuda/extras/CUPTI/include/cupti_metrics.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_callbacks.h" "$(@D)/cuda/extras/CUPTI/include/cupti_callbacks.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_runtime_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_runtime_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti.h" "$(@D)/cuda/extras/CUPTI/include/cupti.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glut.h" "$(@D)/cuda/extras/CUPTI/include/GL/glut.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glu.h" "$(@D)/cuda/extras/CUPTI/include/GL/glu.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glxext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glxext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/wglext.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glx.h" "$(@D)/cuda/extras/CUPTI/include/GL/glx.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/wglew.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglew.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/gl.h" "$(@D)/cuda/extras/CUPTI/include/GL/gl.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glew.h" "$(@D)/cuda/extras/CUPTI/include/GL/glew.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_driver_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_driver_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_nvtx_meta.h" 
"$(@D)/cuda/extras/CUPTI/include/generated_nvtx_meta.h" + """, +) + +genrule( + name = "cuda-lib", + outs = [ + "cuda/lib/libcuda.so", + "cuda/lib/libcudart.so.8.0", + "cuda/lib/libcudart_static.a", + "cuda/lib/libcublas.so.8.0", + "cuda/lib/libcusolver.so.8.0", + "cuda/lib/libcurand.so.8.0", + "cuda/lib/libcufft.so.8.0", + "cuda/lib/libcudnn.so.6", + "cuda/lib/libcupti.so.8.0", + ], + cmd = """ +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61" "$(@D)/cuda/lib/libcudart.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcublas.so.8.0.71" "$(@D)/cuda/lib/libcublas.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcusolver.so.8.0.61" "$(@D)/cuda/lib/libcusolver.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcurand.so.8.0.61" "$(@D)/cuda/lib/libcurand.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcufft.so.8.0.61" "$(@D)/cuda/lib/libcufft.so.8.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21" "$(@D)/cuda/lib/libcudnn.so.6" && cp "/usr/local/cuda-8.0/extras/CUPTI/lib64/libcupti.so.8.0.61" "$(@D)/cuda/lib/libcupti.so.8.0" + """, +) + +genrule( + name = "cudnn-include", + outs = [ + "cuda/include/cudnn.h", + ], + cmd = """ +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/include/cudnn.h" "$(@D)/cudnn.h" + """, +) diff --git a/third_party/toolchains/gpus/cuda/build_defs.bzl 
b/third_party/toolchains/gpus/cuda/build_defs.bzl new file mode 100644 index 0000000000000000000000000000000000000000..badaf4301934cb6c87cfecbacf0b3bdfff443fe4 --- /dev/null +++ b/third_party/toolchains/gpus/cuda/build_defs.bzl @@ -0,0 +1,37 @@ +# Macros for building CUDA code used with Bazel remote +# execution service. +# DO NOT EDIT: automatically generated file + +def if_cuda(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with CUDA. + + Returns a select statement which evaluates to if_true if we're building + with CUDA enabled. Otherwise, the select statement evaluates to if_false. + + """ + return select({ + "@local_config_cuda//cuda:using_nvcc": if_true, + "@local_config_cuda//cuda:using_clang": if_true, + "//conditions:default": if_false + }) + + +def cuda_default_copts(): + """Default options for all CUDA compilations.""" + return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + ["--cuda-gpu-arch=sm_30"]) + + +def cuda_is_configured(): + """Returns true if CUDA was enabled during the configure process.""" + return True + +def if_cuda_is_configured(x): + """Tests if the CUDA was enabled during the configure process. + + Unlike if_cuda(), this does not require that we are building with + --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries. + """ + if cuda_is_configured(): + return x + return [] + diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/layout-test.ts b/third_party/toolchains/gpus/cuda/cuda/cuda_config.h similarity index 56% rename from tensorflow/tensorboard/components/tf_graph_common/test/layout-test.ts rename to third_party/toolchains/gpus/cuda/cuda/cuda_config.h index b4884413c9d4f0b2e3d61d283736174f6549819b..f6662274cc0a31073adbd9a976a42af93f200cfd 100644 --- a/tensorflow/tensorboard/components/tf_graph_common/test/layout-test.ts +++ b/third_party/toolchains/gpus/cuda/cuda/cuda_config.h @@ -1,23 +1,27 @@ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
-Licensed under the Apache License, Version 2.0 (the 'License'); +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -suite('layout', () => { - let assert = chai.assert; +// DO NOT EDIT: automatically generated file +#ifndef CUDA_CUDA_CONFIG_H_ +#define CUDA_CUDA_CONFIG_H_ - test('dagre exists', () => { assert.isTrue(dagre != null); }); +#define TF_CUDA_CAPABILITIES CudaVersion("3.0") - // TODO(bp): write tests. +#define TF_CUDA_VERSION "8.0" +#define TF_CUDNN_VERSION "5" -}); +#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-8.0" + +#endif // CUDA_CUDA_CONFIG_H_ diff --git a/third_party/typings.bzl b/third_party/typings.bzl deleted file mode 100644 index d0c9eddbb3f52803310caed8775840b5af8fbbfa..0000000000000000000000000000000000000000 --- a/third_party/typings.bzl +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# TensorBoard typing dependencies - -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") - -def tensorboard_typings_workspace(): - filegroup_external( - name = "org_definitelytyped", - licenses = ["notice"], # MIT - sha256_urls = { - "b7da645f6e5555feb7aeede73775da0023ce2257df9c8e76c9159266035a9c0d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts", - ], - "177293828c7a206bf2a7f725753d51396d38668311aa37c96445f91bbf8128a7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3 - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3 - ], - "e4cd3d5de0eb3bc7b1063b50d336764a0ac82a658b39b5cf90511f489ffdee60": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts", - ], - "695a03dd2ccb238161d97160b239ab841562710e5c4e42886aefd4ace2ce152e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/mocha/mocha.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/mocha/mocha.d.ts", - ], - "513ccd9ee1c708881120eeacd56788fc3b3da8e5c6172b20324cebbe858803fe": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/708609e0764daeb5eb64104af7aca50c520c4e6e/sinon/sinon.d.ts", - 
"https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/708609e0764daeb5eb64104af7aca50c520c4e6e/sinon/sinon.d.ts", - ], - "44eba36339bd1c0792072b7b204ee926fe5ffe1e9e2da916e67ac55548e3668a": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/a872802c0c84ba98ff207d5e673a1fa867c67fd6/polymer/polymer.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/a872802c0c84ba98ff207d5e673a1fa867c67fd6/polymer/polymer.d.ts", - ], - "9453c3e6bae824e90758c3b38975c1ed77e6abd79bf513bcb08368fcdb14898e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/f5407eba29c04fb8387c86df27512bd055b195d2/threejs/three.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/f5407eba29c04fb8387c86df27512bd055b195d2/threejs/three.d.ts", - ], - "691756a6eb455f340c9e834de0d49fff269e7b8c1799c2454465dcd6a4435b80": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/46719185c564694c5583c4b7ad94dbb786ecad46/webcomponents.js/webcomponents.js.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/46719185c564694c5583c4b7ad94dbb786ecad46/webcomponents.js/webcomponents.js.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_array", - licenses = ["notice"], # MIT - sha256_urls = { - "61e7abb7b1f01fbcb0cab8cf39003392f422566209edd681fbd070eaa84ca000": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_axis", - licenses = ["notice"], # MIT - sha256_urls = { - "95f75c8dcc89850b2e72581d96a7b5f46ea4ac852f828893f141f14a597421f9": [ - 
"http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_brush", - licenses = ["notice"], # MIT - sha256_urls = { - "a2738e693ce8a8640c2d29001e77582c9c361fd23bda44db471629866b60ada7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_chord", - licenses = ["notice"], # MIT - sha256_urls = { - "c54d24756eb6d744b31e538ad9bab3a75f6d54e2288b29cc72338d4a057d3e83": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_collection", - licenses = ["notice"], # MIT - sha256_urls = { - "f987667167b1d2970911247e325eb1c37ca0823646f81ccec837ae59039822f7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_color", - licenses = ["notice"], # MIT - sha256_urls = { - 
"9580c81f38ddcce7be0ac9bd3d0d083adebc34e17441709f90b9e4dcd1c19a56": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_dispatch", - licenses = ["notice"], # MIT - sha256_urls = { - "169f80b4cceca8e2e9ed384d81a5db0624cc01a26451dfb5a7e0cec6ea9cfb06": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_drag", - licenses = ["notice"], # MIT - sha256_urls = { - "08d35d139dde58c2722be98d718d01204fd6167d310f09b379e832f3c741489d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_dsv", - licenses = ["notice"], # MIT - sha256_urls = { - "62594d00cf9e4bb895339c8e56f64330e202a5eb2a0fa580a1f6e6336f2c93ce": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_ease", - licenses = ["notice"], # MIT - sha256_urls = { - 
"d1cf8f99b7bf758c2ba3c0a4ce553e151d4d9b4cf45a6e8bd0edec7ce90f725b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_force", - licenses = ["notice"], # MIT - sha256_urls = { - "288421e2008668d2076a4684657dd3d29b992832ef02c552981eb94a91042553": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_format", - licenses = ["notice"], # MIT - sha256_urls = { - "b42cb17e580c1fd0b64d478f7bd80ca806efaefda24426a833cf1f30a7275bca": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_hierarchy", - licenses = ["notice"], # MIT - sha256_urls = { - "a5683f5835d8716c6b89c075235078438cfab5897023ed720bfa492e244e969e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_interpolate", - licenses = ["notice"], # MIT 
- sha256_urls = { - "590a71b741323ac3139b333ec8b743e24717fdd5b32bcff48ee521162a9dfe1c": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_path", - licenses = ["notice"], # MIT - sha256_urls = { - "96f35ba041bcaa265e2b373ee675177410d44d31c980e4f7fbeefd4bcba15b00": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_polygon", - licenses = ["notice"], # MIT - sha256_urls = { - "ce453451e8105cac6a4f4a4263ca2142ebb4bf442e342f470a81da691f220fcb": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_quadtree", - licenses = ["notice"], # MIT - sha256_urls = { - "238e278f1be5d6985a19800800cffee80f81199f71d848e3bbc288d1791a6f90": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_queue", - 
licenses = ["notice"], # MIT - sha256_urls = { - "e6ae19aad83495475653578de64fb9d6bf9764eda6c84d70f7935ec84bcc482e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_random", - licenses = ["notice"], # MIT - sha256_urls = { - "d31b92ed86c23ec0a4776f99fa81ff033c95b96c8304d8aa9baf3b94af779aa8": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_request", - licenses = ["notice"], # MIT - sha256_urls = { - "44bb7b07d977028e6567540a3303b06fc9b33fb0960bc75c520e0733c840d89f": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_scale", - licenses = ["notice"], # MIT - sha256_urls = { - "02ce7c644ba34bd1abb84da2e832f248b048b6a23812be4365bd837f186c9f1f": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts", - ], - }, - ) - - filegroup_external( - name = 
"org_definitelytyped_types_d3_selection", - licenses = ["notice"], # MIT - sha256_urls = { - "699043ddb28dfa5e46d87bc6a24cfc6d604237f298259d3fb3c7066e05e8c86e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_shape", - licenses = ["notice"], # MIT - sha256_urls = { - "62668a7aaaf6232762b544f9f89c0f557ca7cfb0cd343a358dda7ecbe26f5739": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_time", - licenses = ["notice"], # MIT - sha256_urls = { - "0502490ce682fd9265fb1d5d693ce6cd82e3b05e5f5ee3433731266ecb03d5fc": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_timer", - licenses = ["notice"], # MIT - sha256_urls = { - "6f191f9aea704aa64b1defa40dfdff1447a6e6bb815feff1660f894500a9c94d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts", - ], - }, - ) - - filegroup_external( 
- name = "org_definitelytyped_types_d3_transition", - licenses = ["notice"], # MIT - sha256_urls = { - "a0a7c0c9bfb5c7d6d9d22a8d16b4484b66d13f2ed226954037546cb3da4098ba": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_voronoi", - licenses = ["notice"], # MIT - sha256_urls = { - "c6bd5f229f915151d0ef678fe50b1aa6a62334ea0a8c6fc0effbac9f7032efc7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_zoom", - licenses = ["notice"], # MIT - sha256_urls = { - "a25dc17fbd304cf7a0e5e7bbb8339c930d464eb40c4d6e5f839ce9c0191f4110": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts", - ], - }, - ) diff --git a/third_party/werkzeug.BUILD b/third_party/werkzeug.BUILD deleted file mode 100644 index 72a1402030d150c21b5d43261a4d5e2c0f1bce91..0000000000000000000000000000000000000000 --- a/third_party/werkzeug.BUILD +++ /dev/null @@ -1,14 +0,0 @@ -# Description: -# Werkzeug provides utilities for making WSGI applications - -licenses(["notice"]) # BSD 3-Clause - -exports_files(["LICENSE"]) - -# Note: this library includes test code. Consider creating a testonly target. 
-py_library( - name = "werkzeug", - srcs = glob(["werkzeug/*.py"]), - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -) diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD index 279e6395b03a3c86b4b3fe25958ebafa4cb75062..85096688914a1598ef1d51b71721d860398947cb 100644 --- a/third_party/zlib.BUILD +++ b/third_party/zlib.BUILD @@ -2,6 +2,18 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # BSD/MIT-like license (for zlib) +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "windows_msvc", + values = {"cpu": "x64_windows_msvc"}, + visibility = ["//visibility:public"], +) + cc_library( name = "zlib", srcs = [ @@ -32,9 +44,13 @@ cc_library( "zutil.h", ], hdrs = ["zlib.h"], - copts = [ - "-Wno-shift-negative-value", - "-Wno-implicit-function-declaration", - ], + copts = select({ + ":windows": [], + ":windows_msvc": [], + "//conditions:default": [ + "-Wno-shift-negative-value", + "-Wno-implicit-function-declaration", + ], + }), includes = ["."], ) diff --git a/tools/bazel.rc b/tools/bazel.rc index e67a290cf40ca7f688dfdb03210786c8c85abe48..414ddf2e475da051cad4a4534a3a0ca955229997 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -11,6 +11,9 @@ build:mkl --define=using_mkl=true build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain build:sycl --define=using_sycl=true +build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain +build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE + build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain build:sycl_asan --define=using_sycl=true --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address